diff --git a/jetforce/tls.py b/jetforce/tls.py index aea659d..ca801a4 100644 --- a/jetforce/tls.py +++ b/jetforce/tls.py @@ -94,7 +94,14 @@ class GeminiCertificateOptions(CertificateOptions): https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py """ - def verify_callback(self, conn, cert, errno, depth, preverify_ok): + def verify_callback( + self, + conn: OpenSSL.SSL.Connection, + cert: OpenSSL.crypto.X509, + errno: int, + depth: int, + preverify_ok: int, + ) -> bool: """ Callback used by OpenSSL for client certificate verification. @@ -106,7 +113,9 @@ class GeminiCertificateOptions(CertificateOptions): conn.verified = preverify_ok return True - def proto_select_callback(self, conn, protocols): + def proto_select_callback( + self, conn: OpenSSL.SSL.Connection, protocols: typing.List[bytes] + ) -> bytes: """ Callback used by OpenSSL for ALPN support. @@ -118,6 +127,16 @@ class GeminiCertificateOptions(CertificateOptions): else: return b"" + def sni_callback(self, conn: OpenSSL.SSL.Connection) -> None: + """ + Callback used by OpenSSL for SNI support. + + We can inspect the servername requested by the client using + conn.get_servername(), and attach an appropriate context using + conn.set_context(new_context). + """ + pass + def __init__( self, certfile: str, @@ -177,4 +196,6 @@ class GeminiCertificateOptions(CertificateOptions): ctx.set_alpn_select_callback(self.proto_select_callback) ctx.set_alpn_protos(self._acceptableProtocols) + ctx.set_tlsext_servername_callback(self.sni_callback) + return ctx diff --git a/jetforce_client.py b/jetforce_client.py index 1b08b56..d8b2282 100755 --- a/jetforce_client.py +++ b/jetforce_client.py @@ -15,7 +15,7 @@ context.check_hostname = False context.verify_mode = ssl.CERT_NONE -def fetch(url: str, host: str = None, port: str = None): +def fetch(url, host=None, port=None, use_sni=False): parsed_url = urllib.parse.urlparse(url) if not parsed_url.scheme: parsed_url = urllib.parse.urlparse(f"gemini://{url}") @@ -23,8 +23,10 @@ def fetch(url: str, host: str = None, port: str = None): host = host or parsed_url.hostname port = port or parsed_url.port or 1965 + server_hostname = host if use_sni else None + with socket.create_connection((host, port)) as sock: - with context.wrap_socket(sock) as ssock: + with context.wrap_socket(sock, server_hostname=server_hostname) as ssock: ssock.sendall((url + "\r\n").encode()) fp = ssock.makefile("rb", buffering=0) data = fp.read(1024) @@ -44,12 +46,18 @@ def run_client(): ) parser.add_argument("--certfile", help="Optional client certificate") parser.add_argument("--keyfile", help="Optional client key") - args = parser.parse_args() + parser.add_argument("--alpn-protocol", help="Indicate the protocol using ALPN") + parser.add_argument( + "--use-sni", action="store_true", help="Specify the server hostname via SNI" + ) + args = parser.parse_args() if args.certfile: context.load_cert_chain(args.certfile, args.keyfile) + if args.alpn_protocol: + context.set_alpn_protocols([args.alpn_protocol]) - fetch(args.url, args.host, args.port) + fetch(args.url, args.host, args.port, args.use_sni) if __name__ == "__main__":