Move server framework from socketserver to twisted :)

This commit is contained in:
Michael Lazar 2020-05-12 23:50:12 -04:00
parent ef4023bd5c
commit 335d79ad54
2 changed files with 223 additions and 153 deletions

View File

@ -35,6 +35,7 @@ StaticDirectoryApplication:
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import base64
import codecs import codecs
import dataclasses import dataclasses
import datetime import datetime
@ -43,7 +44,6 @@ import os
import pathlib import pathlib
import re import re
import socket import socket
import socketserver
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -56,6 +56,16 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from OpenSSL import SSL from OpenSSL import SSL
from twisted.internet import reactor
from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.ssl import CertificateOptions
from twisted.internet.tcp import Port
from twisted.protocols.basic import LineOnlyReceiver
CN = x509.NameOID.COMMON_NAME
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
sys.exit("Fatal Error: jetforce requires Python 3.7+") sys.exit("Fatal Error: jetforce requires Python 3.7+")
@ -264,6 +274,49 @@ class RoutePattern:
return re.fullmatch(self.path, request_path) return re.fullmatch(self.path, request_path)
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
"""
Utility function to generate an ad-hoc self-signed SSL certificate.
"""
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
if not os.path.exists(certfile) or not os.path.exists(keyfile):
backend = default_backend()
print("Generating private key...", file=sys.stderr)
private_key = rsa.generate_private_key(65537, 2048, default_backend())
with open(keyfile, "wb") as fp:
# noinspection PyTypeChecker
key_data = private_key.private_bytes(
serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
fp.write(key_data)
print("Generating certificate...", file=sys.stderr)
common_name = x509.NameAttribute(CN, hostname)
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
serial_number=x509.random_serial_number(),
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
)
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
fp.write(cert_data)
return certfile, keyfile
class JetforceApplication: class JetforceApplication:
""" """
Base Jetforce application class with primitive URL routing. Base Jetforce application class with primitive URL routing.
@ -514,7 +567,46 @@ class StaticDirectoryApplication(JetforceApplication):
return Response(Status.NOT_FOUND, "Not Found") return Response(Status.NOT_FOUND, "Not Found")
class GeminiRequestHandler(socketserver.StreamRequestHandler): class GeminiTLSContextFactory:
"""
Generate a sane default SSL context for a Gemini server.
"""
def __init__(
self,
hostname: str = "localhost",
certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None,
):
if certfile is None:
certfile, keyfile = generate_ad_hoc_certificate(hostname)
context = SSL.Context(SSL.TLSv1_2_METHOD)
context.use_certificate_file(certfile)
context.use_privatekey_file(keyfile or certfile)
context.check_privatekey()
if cafile or capath:
context.load_verify_locations(cafile, capath)
context.set_verify(SSL.VERIFY_PEER, self.verify_cb)
self.context = context
def getContext(self) -> SSL.Context:
"""
Return the SSL context, this method must be implemented for twisted.
"""
return self.context
def verify_cb(self, connection, x509, err_no, err_depth, return_code):
"""
Disable all peer certificate validation at the openSSL level in order
to allow self-signed client certificates.
"""
return True
class GeminiProtocol(LineOnlyReceiver):
""" """
Handle a single Gemini Protocol TCP request. Handle a single Gemini Protocol TCP request.
@ -534,26 +626,49 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
server: GeminiServer connected_timestamp: time.struct_time
received_timestamp: time.struct_time request: bytes
url: str url: str
status: int status: int
meta: str meta: str
response_buffer: str response_buffer: str
response_size: int response_size: int
def setup(self) -> None: def __init__(self, server: GeminiServer, app: JetforceApplication):
self.server = server
self.app = app
def connectionMade(self):
"""
This is invoked by twisted after the connection is first established.
"""
self.connected_timestamp = time.localtime()
self.response_size = 0 self.response_size = 0
self.response_buffer = "" self.response_buffer = ""
super().setup()
def handle(self) -> None: def lineReceived(self, line):
""" """
The request handler entry point, called once for each connection. This method is invoked by LineOnlyReceiver for every incoming line.
"""
self.received_timestamp = time.localtime()
self.request.do_handshake()
Because Gemini requests are only ever a single line long, this will
only be called once and we can use it to handle the lifetime of the
connection without managing any state.
"""
self.request = line
try:
try:
self.handle()
finally:
self.flush_status()
try:
self.log_request()
except Exception:
# Malformed request or dropped connection
pass
finally:
self.transport.loseConnection()
def handle(self):
try: try:
self.parse_header() self.parse_header()
except Exception: except Exception:
@ -563,7 +678,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
try: try:
environ = self.build_environ() environ = self.build_environ()
response_generator = self.server.app(environ, self.write_status) response_generator = self.app(environ, self.write_status)
for data in response_generator: for data in response_generator:
self.write_body(data) self.write_body(data)
except Exception: except Exception:
@ -576,30 +691,39 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
Variable names conform to the CGI spec defined in RFC 3875. Variable names conform to the CGI spec defined in RFC 3875.
""" """
url_parts = urllib.parse.urlparse(self.url) url_parts = urllib.parse.urlparse(self.url)
remote_addr, remote_port, *_ = self.client_address client_addr = self.transport.getPeer()
environ = { environ = {
"GEMINI_URL": self.url, "GEMINI_URL": self.url,
"HOSTNAME": self.server.hostname, "HOSTNAME": self.server.hostname,
"PATH_INFO": url_parts.path, "PATH_INFO": url_parts.path,
"QUERY_STRING": url_parts.query, "QUERY_STRING": url_parts.query,
"REMOTE_ADDR": remote_addr, "REMOTE_ADDR": client_addr.host,
"REMOTE_HOST": remote_addr, "REMOTE_HOST": client_addr.host,
"SERVER_NAME": self.server.hostname, "SERVER_NAME": self.server.hostname,
"SERVER_PORT": str(remote_port), "SERVER_PORT": str(client_addr.port),
"SERVER_PROTOCOL": "GEMINI", "SERVER_PROTOCOL": "GEMINI",
"SERVER_SOFTWARE": f"jetforce/{__version__}", "SERVER_SOFTWARE": f"jetforce/{__version__}",
} }
client_cert = self.request.getpeercert() openssl_cert = self.transport.getPeerCertificate()
if client_cert: if openssl_cert:
subject = dict(x[0] for x in client_cert["subject"]) # Extract useful information from the client certificate. These
# mostly follow the naming convention from GLV-1.12556
cert = openssl_cert.to_cryptography()
name_attrs = cert.subject.get_attributes_for_oid(CN)
common_name = name_attrs[0].value if name_attrs else ""
fingerprint_bytes = cert.fingerprint(hashes.SHA256())
fingerprint = base64.b64encode(fingerprint_bytes).decode()
not_before = cert.not_valid_before.strftime("%Y-%m-%dT%H:%M:%SZ")
not_after = cert.not_valid_after.strftime("%Y-%m-%dT%H:%M:%SZ")
environ.update( environ.update(
{ {
"AUTH_TYPE": "CERTIFICATE", "AUTH_TYPE": "CERTIFICATE",
"REMOTE_USER": subject.get("commonName", ""), "REMOTE_USER": common_name,
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"], "TLS_CLIENT_HASH": fingerprint,
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"], "TLS_CLIENT_NOT_BEFORE": not_before,
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"], "TLS_CLIENT_NOT_AFTER": not_after,
"TLS_CLIENT_SERIAL_NUMBER": cert.serial_number,
} }
) )
@ -611,12 +735,10 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
The request is a single UTF-8 line formatted as: <URL>\r\n The request is a single UTF-8 line formatted as: <URL>\r\n
""" """
data = self.rfile.readline(1026) if len(self.request) > 1024:
data = data.rstrip(b"\r\n")
if len(data) > 1024:
raise ValueError("URL exceeds max length of 1024 bytes") raise ValueError("URL exceeds max length of 1024 bytes")
self.url = data.decode() self.url = self.request.decode()
def write_status(self, status: int, meta: str) -> None: def write_status(self, status: int, meta: str) -> None:
""" """
@ -643,7 +765,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
""" """
self.flush_status() self.flush_status()
self.response_size += len(data) self.response_size += len(data)
self.wfile.write(data) self.transport.write(data)
def flush_status(self) -> None: def flush_status(self) -> None:
""" """
@ -652,162 +774,113 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
if self.response_buffer and not self.response_size: if self.response_buffer and not self.response_size:
data = self.response_buffer.encode() data = self.response_buffer.encode()
self.response_size += len(data) self.response_size += len(data)
self.wfile.write(data) self.transport.write(data)
self.response_buffer = "" self.response_buffer = ""
def finish(self) -> None:
self.flush_status()
try:
self.log_request()
except AttributeError:
# Malformed request or dropped connection
pass
super().finish()
def log_request(self) -> None: def log_request(self) -> None:
""" """
Log a gemini request using a format derived from the Common Log Format. Log a gemini request using a format derived from the Common Log Format.
""" """
self.server.log_message( message = '{} [{}] "{}" {} {} {}'.format(
f"{self.client_address[0]} " self.transport.getPeer().host,
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] " time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp),
f'"{self.url}" ' self.url,
f"{self.status} " self.status,
f'"{self.meta}" ' self.meta,
f"{self.response_size}" self.response_size,
) )
self.server.log_message(message)
class GeminiServer(socketserver.ThreadingTCPServer): class GeminiServer(Factory):
""" """
An asynchronous TCP server that uses the asyncio stream abstraction. This class acts as a wrapper around most of the plumbing for twisted.
This is a lightweight class that accepts incoming requests, logs them, and There's not much going on here, the main intention is to make it as simple
sends them to a configurable request handler to be processed. as possible to import and run a server without needing to understand the
complicated class hierarchy and conventions defined by twisted.
""" """
request_handler_class = GeminiRequestHandler # Initializes the pyOpenSSL context object, you may want to override this
# to customize your server's TLS configuration.
tls_context_factory_class = GeminiTLSContextFactory
# Request handler class, you probably don't want to override this.
protocol_class = GeminiProtocol
# The TLS twisted interface class is confusingly named SSL4, even though it
# will accept either IPv4 & IPv6 interfaces.
endpoint_class = SSL4ServerEndpoint
def __init__( def __init__(
self, self,
app: typing.Callable, app: typing.Callable,
reactor: ReactorBase = reactor,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 1965, port: int = 1965,
ssl_context: SSL.Context = None,
hostname: str = "localhost",
) -> None:
self.app = app
self.hostname = hostname
self.ssl_context = ssl_context
super().__init__((host, port), self.request_handler_class, False)
def run(self) -> None:
"""
Launch the main server loop.
"""
self.log_message(ABOUT)
self.log_message(f"Server hostname is {self.hostname}")
try:
self.server_bind()
self.server_activate()
except Exception:
self.server_close()
raise
sock_ip, sock_port, *_ = self.server_address
if self.address_family == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
self.serve_forever()
def get_request(self) -> typing.Tuple[SSL.Connection, typing.Tuple[str, int]]:
"""
Wrap the incoming request in an SSL connection.
"""
# noinspection PyTupleAssignmentBalance
sock, client_addr = super(GeminiServer, self).get_request()
sock = SSL.Connection(self.ssl_context, sock)
return sock, client_addr
def log_message(self, message: str) -> None:
"""
Log a diagnostic server message to stderr, may be overridden.
"""
print(message, file=sys.stderr)
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
"""
Utility function to generate an ad-hoc self-signed SSL certificate.
"""
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
if not os.path.exists(certfile) or not os.path.exists(keyfile):
backend = default_backend()
print("Generating private key...", file=sys.stderr)
private_key = rsa.generate_private_key(65537, 2048, default_backend())
with open(keyfile, "wb") as fp:
# noinspection PyTypeChecker
key_data = private_key.private_bytes(
serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
fp.write(key_data)
print("Generating certificate...", file=sys.stderr)
common_name = x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname)
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
serial_number=x509.random_serial_number(),
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
)
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
fp.write(cert_data)
return certfile, keyfile
def make_ssl_context(
hostname: str = "localhost", hostname: str = "localhost",
certfile: typing.Optional[str] = None, certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None, keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None, cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None, capath: typing.Optional[str] = None,
) -> SSL.Context: **_,
):
self.app = app
self.reactor = reactor
self.host = host
self.port = port
self.hostname = hostname
self.certfile = certfile
self.keyfile = keyfile
self.cafile = cafile
self.capath = capath
def log_message(self, message: str) -> None:
""" """
Generate a sane default SSL context for a Gemini server. Log a diagnostic server message to stderr.
""" """
if certfile is None: print(message, file=sys.stderr)
certfile, keyfile = generate_ad_hoc_certificate(hostname)
context = SSL.Context(SSL.TLSv1_2_METHOD) def on_bind_interface(self, port: Port) -> None:
context.use_certificate_file(certfile) """
context.use_privatekey_file(keyfile or certfile) Log when the server binds to an interface.
context.check_privatekey() """
if cafile or capath: sock_ip, sock_port, *_ = port.socket.getsockname()
context.load_verify_locations(cafile, capath) if port.addressFamily == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
def verify_cb(connection, x509, err_no, err_depth, return_code): def buildProtocol(self, addr) -> GeminiProtocol:
pass """
This method is invoked by twisted once for every incoming connection.
context.set_verify(SSL.VERIFY_PEER, verify_cb) It builds the protocol instance which acts as a request handler and
implements the actual Gemini protocol.
"""
return GeminiProtocol(self, self.app)
return context def run(self) -> None:
"""
This is the main server loop.
"""
self.log_message(ABOUT)
self.log_message(f"Server hostname is {self.hostname}")
tls_context_factory = self.tls_context_factory_class(
hostname=self.hostname,
certfile=self.certfile,
keyfile=self.keyfile,
cafile=self.cafile,
capath=self.capath,
)
endpoint = self.endpoint_class(
reactor=self.reactor,
port=self.port,
sslContextFactory=tls_context_factory,
interface=self.host,
)
endpoint.listen(self).addCallback(self.on_bind_interface)
self.reactor.run()
def run_server() -> None: def run_server() -> None:
@ -816,10 +889,7 @@ def run_server() -> None:
""" """
args = parser.parse_args() args = parser.parse_args()
app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir) app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir)
ssl_context = make_ssl_context( server = GeminiServer(app, **vars(args))
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
server.run() server.run()

View File

@ -16,7 +16,7 @@ setuptools.setup(
author="Michael Lazar", author="Michael Lazar",
author_email="lazar.michael22@gmail.com", author_email="lazar.michael22@gmail.com",
description="An Experimental Gemini Server", description="An Experimental Gemini Server",
install_requires=["cryptography", "pyopenssl"], install_requires=["cryptography", "pyopenssl", "twisted"],
long_description=long_description(), long_description=long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
py_modules=["jetforce", "jetforce_client", "jetforce_diagnostics"], py_modules=["jetforce", "jetforce_client", "jetforce_diagnostics"],