From 335d79ad5489f1c8951710a5d549ba558fa92b13 Mon Sep 17 00:00:00 2001 From: Michael Lazar Date: Tue, 12 May 2020 23:50:12 -0400 Subject: [PATCH] Move server framework from socketserver to twisted :) --- jetforce.py | 374 +++++++++++++++++++++++++++++++--------------------- setup.py | 2 +- 2 files changed, 223 insertions(+), 153 deletions(-) diff --git a/jetforce.py b/jetforce.py index bb0e852..12a37f2 100755 --- a/jetforce.py +++ b/jetforce.py @@ -35,6 +35,7 @@ StaticDirectoryApplication: from __future__ import annotations import argparse +import base64 import codecs import dataclasses import datetime @@ -43,7 +44,6 @@ import os import pathlib import re import socket -import socketserver import subprocess import sys import tempfile @@ -56,6 +56,16 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from OpenSSL import SSL +from twisted.internet import reactor +from twisted.internet.base import ReactorBase +from twisted.internet.endpoints import SSL4ServerEndpoint +from twisted.internet.protocol import Factory +from twisted.internet.ssl import CertificateOptions +from twisted.internet.tcp import Port +from twisted.protocols.basic import LineOnlyReceiver + +CN = x509.NameOID.COMMON_NAME + if sys.version_info < (3, 7): sys.exit("Fatal Error: jetforce requires Python 3.7+") @@ -264,6 +274,49 @@ class RoutePattern: return re.fullmatch(self.path, request_path) +def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]: + """ + Utility function to generate an ad-hoc self-signed SSL certificate. + """ + certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt") + keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key") + + if not os.path.exists(certfile) or not os.path.exists(keyfile): + backend = default_backend() + + print("Generating private key...", file=sys.stderr) + private_key = rsa.generate_private_key(65537, 2048, default_backend()) + with open(keyfile, "wb") as fp: + # noinspection PyTypeChecker + key_data = private_key.private_bytes( + serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + fp.write(key_data) + + print("Generating certificate...", file=sys.stderr) + common_name = x509.NameAttribute(CN, hostname) + subject_name = x509.Name([common_name]) + not_valid_before = datetime.datetime.utcnow() + not_valid_after = not_valid_before + datetime.timedelta(days=365) + certificate = x509.CertificateBuilder( + subject_name=subject_name, + issuer_name=subject_name, + public_key=private_key.public_key(), + serial_number=x509.random_serial_number(), + not_valid_before=not_valid_before, + not_valid_after=not_valid_after, + ) + certificate = certificate.sign(private_key, hashes.SHA256(), backend) + with open(certfile, "wb") as fp: + # noinspection PyTypeChecker + cert_data = certificate.public_bytes(serialization.Encoding.PEM) + fp.write(cert_data) + + return certfile, keyfile + + class JetforceApplication: """ Base Jetforce application class with primitive URL routing. @@ -514,7 +567,46 @@ class StaticDirectoryApplication(JetforceApplication): return Response(Status.NOT_FOUND, "Not Found") -class GeminiRequestHandler(socketserver.StreamRequestHandler): +class GeminiTLSContextFactory: + """ + Generate a sane default SSL context for a Gemini server. + """ + + def __init__( + self, + hostname: str = "localhost", + certfile: typing.Optional[str] = None, + keyfile: typing.Optional[str] = None, + cafile: typing.Optional[str] = None, + capath: typing.Optional[str] = None, + ): + if certfile is None: + certfile, keyfile = generate_ad_hoc_certificate(hostname) + + context = SSL.Context(SSL.TLSv1_2_METHOD) + context.use_certificate_file(certfile) + context.use_privatekey_file(keyfile or certfile) + context.check_privatekey() + if cafile or capath: + context.load_verify_locations(cafile, capath) + context.set_verify(SSL.VERIFY_PEER, self.verify_cb) + self.context = context + + def getContext(self) -> SSL.Context: + """ + Return the SSL context, this method must be implemented for twisted. + """ + return self.context + + def verify_cb(self, connection, x509, err_no, err_depth, return_code): + """ + Disable all peer certificate validation at the openSSL level in order + to allow self-signed client certificates. + """ + return True + + +class GeminiProtocol(LineOnlyReceiver): """ Handle a single Gemini Protocol TCP request. @@ -534,26 +626,49 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" - server: GeminiServer - received_timestamp: time.struct_time + connected_timestamp: time.struct_time + request: bytes url: str status: int meta: str response_buffer: str response_size: int - def setup(self) -> None: + def __init__(self, server: GeminiServer, app: JetforceApplication): + self.server = server + self.app = app + + def connectionMade(self): + """ + This is invoked by twisted after the connection is first established. + """ + self.connected_timestamp = time.localtime() self.response_size = 0 self.response_buffer = "" - super().setup() - def handle(self) -> None: + def lineReceived(self, line): """ - The request handler entry point, called once for each connection. - """ - self.received_timestamp = time.localtime() - self.request.do_handshake() + This method is invoked by LineOnlyReceiver for every incoming line. + Because Gemini requests are only ever a single line long, this will + only be called once and we can use it to handle the lifetime of the + connection without managing any state. + """ + self.request = line + try: + try: + self.handle() + finally: + self.flush_status() + try: + self.log_request() + except Exception: + # Malformed request or dropped connection + pass + finally: + self.transport.loseConnection() + + def handle(self): try: self.parse_header() except Exception: @@ -563,7 +678,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): try: environ = self.build_environ() - response_generator = self.server.app(environ, self.write_status) + response_generator = self.app(environ, self.write_status) for data in response_generator: self.write_body(data) except Exception: @@ -576,30 +691,39 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): Variable names conform to the CGI spec defined in RFC 3875. """ url_parts = urllib.parse.urlparse(self.url) - remote_addr, remote_port, *_ = self.client_address + client_addr = self.transport.getPeer() environ = { "GEMINI_URL": self.url, "HOSTNAME": self.server.hostname, "PATH_INFO": url_parts.path, "QUERY_STRING": url_parts.query, - "REMOTE_ADDR": remote_addr, - "REMOTE_HOST": remote_addr, + "REMOTE_ADDR": client_addr.host, + "REMOTE_HOST": client_addr.host, "SERVER_NAME": self.server.hostname, - "SERVER_PORT": str(remote_port), + "SERVER_PORT": str(client_addr.port), "SERVER_PROTOCOL": "GEMINI", "SERVER_SOFTWARE": f"jetforce/{__version__}", } - client_cert = self.request.getpeercert() - if client_cert: - subject = dict(x[0] for x in client_cert["subject"]) + openssl_cert = self.transport.getPeerCertificate() + if openssl_cert: + # Extract useful information from the client certificate. These + # mostly follow the naming convention from GLV-1.12556 + cert = openssl_cert.to_cryptography() + name_attrs = cert.subject.get_attributes_for_oid(CN) + common_name = name_attrs[0].value if name_attrs else "" + fingerprint_bytes = cert.fingerprint(hashes.SHA256()) + fingerprint = base64.b64encode(fingerprint_bytes).decode() + not_before = cert.not_valid_before.strftime("%Y-%m-%dT%H:%M:%SZ") + not_after = cert.not_valid_after.strftime("%Y-%m-%dT%H:%M:%SZ") environ.update( { "AUTH_TYPE": "CERTIFICATE", - "REMOTE_USER": subject.get("commonName", ""), - "TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"], - "TLS_CLIENT_NOT_AFTER": client_cert["notAfter"], - "TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"], + "REMOTE_USER": common_name, + "TLS_CLIENT_HASH": fingerprint, + "TLS_CLIENT_NOT_BEFORE": not_before, + "TLS_CLIENT_NOT_AFTER": not_after, + "TLS_CLIENT_SERIAL_NUMBER": cert.serial_number, } ) @@ -611,12 +735,10 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): The request is a single UTF-8 line formatted as: \r\n """ - data = self.rfile.readline(1026) - data = data.rstrip(b"\r\n") - if len(data) > 1024: + if len(self.request) > 1024: raise ValueError("URL exceeds max length of 1024 bytes") - self.url = data.decode() + self.url = self.request.decode() def write_status(self, status: int, meta: str) -> None: """ @@ -643,7 +765,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): """ self.flush_status() self.response_size += len(data) - self.wfile.write(data) + self.transport.write(data) def flush_status(self) -> None: """ @@ -652,162 +774,113 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler): if self.response_buffer and not self.response_size: data = self.response_buffer.encode() self.response_size += len(data) - self.wfile.write(data) + self.transport.write(data) self.response_buffer = "" - def finish(self) -> None: - self.flush_status() - try: - self.log_request() - except AttributeError: - # Malformed request or dropped connection - pass - super().finish() - def log_request(self) -> None: """ Log a gemini request using a format derived from the Common Log Format. """ - self.server.log_message( - f"{self.client_address[0]} " - f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] " - f'"{self.url}" ' - f"{self.status} " - f'"{self.meta}" ' - f"{self.response_size}" + message = '{} [{}] "{}" {} {} {}'.format( + self.transport.getPeer().host, + time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp), + self.url, + self.status, + self.meta, + self.response_size, ) + self.server.log_message(message) -class GeminiServer(socketserver.ThreadingTCPServer): +class GeminiServer(Factory): """ - An asynchronous TCP server that uses the asyncio stream abstraction. + This class acts as a wrapper around most of the plumbing for twisted. - This is a lightweight class that accepts incoming requests, logs them, and - sends them to a configurable request handler to be processed. + There's not much going on here, the main intention is to make it as simple + as possible to import and run a server without needing to understand the + complicated class hierarchy and conventions defined by twisted. """ - request_handler_class = GeminiRequestHandler + # Initializes the pyOpenSSL context object, you may want to override this + # to customize your server's TLS configuration. + tls_context_factory_class = GeminiTLSContextFactory + + # Request handler class, you probably don't want to override this. + protocol_class = GeminiProtocol + + # The TLS twisted interface class is confusingly named SSL4, even though it + # will accept either IPv4 & IPv6 interfaces. + endpoint_class = SSL4ServerEndpoint def __init__( self, app: typing.Callable, + reactor: ReactorBase = reactor, host: str = "127.0.0.1", port: int = 1965, - ssl_context: SSL.Context = None, hostname: str = "localhost", - ) -> None: - + certfile: typing.Optional[str] = None, + keyfile: typing.Optional[str] = None, + cafile: typing.Optional[str] = None, + capath: typing.Optional[str] = None, + **_, + ): self.app = app + self.reactor = reactor + self.host = host + self.port = port self.hostname = hostname - self.ssl_context = ssl_context - super().__init__((host, port), self.request_handler_class, False) + self.certfile = certfile + self.keyfile = keyfile + self.cafile = cafile + self.capath = capath - def run(self) -> None: + def log_message(self, message: str) -> None: """ - Launch the main server loop. + Log a diagnostic server message to stderr. """ - self.log_message(ABOUT) - self.log_message(f"Server hostname is {self.hostname}") - try: - self.server_bind() - self.server_activate() - except Exception: - self.server_close() - raise + print(message, file=sys.stderr) - sock_ip, sock_port, *_ = self.server_address - if self.address_family == socket.AF_INET: + def on_bind_interface(self, port: Port) -> None: + """ + Log when the server binds to an interface. + """ + sock_ip, sock_port, *_ = port.socket.getsockname() + if port.addressFamily == socket.AF_INET: self.log_message(f"Listening on {sock_ip}:{sock_port}") else: self.log_message(f"Listening on [{sock_ip}]:{sock_port}") - self.serve_forever() - - def get_request(self) -> typing.Tuple[SSL.Connection, typing.Tuple[str, int]]: + def buildProtocol(self, addr) -> GeminiProtocol: """ - Wrap the incoming request in an SSL connection. + This method is invoked by twisted once for every incoming connection. + + It builds the protocol instance which acts as a request handler and + implements the actual Gemini protocol. """ - # noinspection PyTupleAssignmentBalance - sock, client_addr = super(GeminiServer, self).get_request() - sock = SSL.Connection(self.ssl_context, sock) - return sock, client_addr + return GeminiProtocol(self, self.app) - def log_message(self, message: str) -> None: + def run(self) -> None: """ - Log a diagnostic server message to stderr, may be overridden. + This is the main server loop. """ - print(message, file=sys.stderr) - - -def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]: - """ - Utility function to generate an ad-hoc self-signed SSL certificate. - """ - certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt") - keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key") - - if not os.path.exists(certfile) or not os.path.exists(keyfile): - backend = default_backend() - - print("Generating private key...", file=sys.stderr) - private_key = rsa.generate_private_key(65537, 2048, default_backend()) - with open(keyfile, "wb") as fp: - # noinspection PyTypeChecker - key_data = private_key.private_bytes( - serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - fp.write(key_data) - - print("Generating certificate...", file=sys.stderr) - common_name = x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname) - subject_name = x509.Name([common_name]) - not_valid_before = datetime.datetime.utcnow() - not_valid_after = not_valid_before + datetime.timedelta(days=365) - certificate = x509.CertificateBuilder( - subject_name=subject_name, - issuer_name=subject_name, - public_key=private_key.public_key(), - serial_number=x509.random_serial_number(), - not_valid_before=not_valid_before, - not_valid_after=not_valid_after, + self.log_message(ABOUT) + self.log_message(f"Server hostname is {self.hostname}") + tls_context_factory = self.tls_context_factory_class( + hostname=self.hostname, + certfile=self.certfile, + keyfile=self.keyfile, + cafile=self.cafile, + capath=self.capath, ) - certificate = certificate.sign(private_key, hashes.SHA256(), backend) - with open(certfile, "wb") as fp: - # noinspection PyTypeChecker - cert_data = certificate.public_bytes(serialization.Encoding.PEM) - fp.write(cert_data) - - return certfile, keyfile - - -def make_ssl_context( - hostname: str = "localhost", - certfile: typing.Optional[str] = None, - keyfile: typing.Optional[str] = None, - cafile: typing.Optional[str] = None, - capath: typing.Optional[str] = None, -) -> SSL.Context: - """ - Generate a sane default SSL context for a Gemini server. - """ - if certfile is None: - certfile, keyfile = generate_ad_hoc_certificate(hostname) - - context = SSL.Context(SSL.TLSv1_2_METHOD) - context.use_certificate_file(certfile) - context.use_privatekey_file(keyfile or certfile) - context.check_privatekey() - if cafile or capath: - context.load_verify_locations(cafile, capath) - - def verify_cb(connection, x509, err_no, err_depth, return_code): - pass - - context.set_verify(SSL.VERIFY_PEER, verify_cb) - - return context + endpoint = self.endpoint_class( + reactor=self.reactor, + port=self.port, + sslContextFactory=tls_context_factory, + interface=self.host, + ) + endpoint.listen(self).addCallback(self.on_bind_interface) + self.reactor.run() def run_server() -> None: @@ -816,10 +889,7 @@ def run_server() -> None: """ args = parser.parse_args() app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir) - ssl_context = make_ssl_context( - args.hostname, args.certfile, args.keyfile, args.cafile, args.capath - ) - server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname) + server = GeminiServer(app, **vars(args)) server.run() diff --git a/setup.py b/setup.py index 883a26d..30275dd 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setuptools.setup( author="Michael Lazar", author_email="lazar.michael22@gmail.com", description="An Experimental Gemini Server", - install_requires=["cryptography", "pyopenssl"], + install_requires=["cryptography", "pyopenssl", "twisted"], long_description=long_description(), long_description_content_type="text/markdown", py_modules=["jetforce", "jetforce_client", "jetforce_diagnostics"],