Move server framework from socketserver to twisted :)

This commit is contained in:
Michael Lazar 2020-05-12 23:50:12 -04:00
parent ef4023bd5c
commit 335d79ad54
2 changed files with 223 additions and 153 deletions

View File

@ -35,6 +35,7 @@ StaticDirectoryApplication:
from __future__ import annotations
import argparse
import base64
import codecs
import dataclasses
import datetime
@ -43,7 +44,6 @@ import os
import pathlib
import re
import socket
import socketserver
import subprocess
import sys
import tempfile
@ -56,6 +56,16 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from OpenSSL import SSL
from twisted.internet import reactor
from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.ssl import CertificateOptions
from twisted.internet.tcp import Port
from twisted.protocols.basic import LineOnlyReceiver
CN = x509.NameOID.COMMON_NAME
if sys.version_info < (3, 7):
sys.exit("Fatal Error: jetforce requires Python 3.7+")
@ -264,6 +274,49 @@ class RoutePattern:
return re.fullmatch(self.path, request_path)
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
"""
Utility function to generate an ad-hoc self-signed SSL certificate.
"""
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
if not os.path.exists(certfile) or not os.path.exists(keyfile):
backend = default_backend()
print("Generating private key...", file=sys.stderr)
private_key = rsa.generate_private_key(65537, 2048, default_backend())
with open(keyfile, "wb") as fp:
# noinspection PyTypeChecker
key_data = private_key.private_bytes(
serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
fp.write(key_data)
print("Generating certificate...", file=sys.stderr)
common_name = x509.NameAttribute(CN, hostname)
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
serial_number=x509.random_serial_number(),
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
)
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
fp.write(cert_data)
return certfile, keyfile
class JetforceApplication:
"""
Base Jetforce application class with primitive URL routing.
@ -514,7 +567,46 @@ class StaticDirectoryApplication(JetforceApplication):
return Response(Status.NOT_FOUND, "Not Found")
class GeminiRequestHandler(socketserver.StreamRequestHandler):
class GeminiTLSContextFactory:
"""
Generate a sane default SSL context for a Gemini server.
"""
def __init__(
self,
hostname: str = "localhost",
certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None,
):
if certfile is None:
certfile, keyfile = generate_ad_hoc_certificate(hostname)
context = SSL.Context(SSL.TLSv1_2_METHOD)
context.use_certificate_file(certfile)
context.use_privatekey_file(keyfile or certfile)
context.check_privatekey()
if cafile or capath:
context.load_verify_locations(cafile, capath)
context.set_verify(SSL.VERIFY_PEER, self.verify_cb)
self.context = context
def getContext(self) -> SSL.Context:
"""
Return the SSL context, this method must be implemented for twisted.
"""
return self.context
def verify_cb(self, connection, x509, err_no, err_depth, return_code):
"""
Disable all peer certificate validation at the openSSL level in order
to allow self-signed client certificates.
"""
return True
class GeminiProtocol(LineOnlyReceiver):
"""
Handle a single Gemini Protocol TCP request.
@ -534,26 +626,49 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
server: GeminiServer
received_timestamp: time.struct_time
connected_timestamp: time.struct_time
request: bytes
url: str
status: int
meta: str
response_buffer: str
response_size: int
def setup(self) -> None:
def __init__(self, server: GeminiServer, app: JetforceApplication):
self.server = server
self.app = app
def connectionMade(self):
"""
This is invoked by twisted after the connection is first established.
"""
self.connected_timestamp = time.localtime()
self.response_size = 0
self.response_buffer = ""
super().setup()
def handle(self) -> None:
def lineReceived(self, line):
"""
The request handler entry point, called once for each connection.
"""
self.received_timestamp = time.localtime()
self.request.do_handshake()
This method is invoked by LineOnlyReceiver for every incoming line.
Because Gemini requests are only ever a single line long, this will
only be called once and we can use it to handle the lifetime of the
connection without managing any state.
"""
self.request = line
try:
try:
self.handle()
finally:
self.flush_status()
try:
self.log_request()
except Exception:
# Malformed request or dropped connection
pass
finally:
self.transport.loseConnection()
def handle(self):
try:
self.parse_header()
except Exception:
@ -563,7 +678,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
try:
environ = self.build_environ()
response_generator = self.server.app(environ, self.write_status)
response_generator = self.app(environ, self.write_status)
for data in response_generator:
self.write_body(data)
except Exception:
@ -576,30 +691,39 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
Variable names conform to the CGI spec defined in RFC 3875.
"""
url_parts = urllib.parse.urlparse(self.url)
remote_addr, remote_port, *_ = self.client_address
client_addr = self.transport.getPeer()
environ = {
"GEMINI_URL": self.url,
"HOSTNAME": self.server.hostname,
"PATH_INFO": url_parts.path,
"QUERY_STRING": url_parts.query,
"REMOTE_ADDR": remote_addr,
"REMOTE_HOST": remote_addr,
"REMOTE_ADDR": client_addr.host,
"REMOTE_HOST": client_addr.host,
"SERVER_NAME": self.server.hostname,
"SERVER_PORT": str(remote_port),
"SERVER_PORT": str(client_addr.port),
"SERVER_PROTOCOL": "GEMINI",
"SERVER_SOFTWARE": f"jetforce/{__version__}",
}
client_cert = self.request.getpeercert()
if client_cert:
subject = dict(x[0] for x in client_cert["subject"])
openssl_cert = self.transport.getPeerCertificate()
if openssl_cert:
# Extract useful information from the client certificate. These
# mostly follow the naming convention from GLV-1.12556
cert = openssl_cert.to_cryptography()
name_attrs = cert.subject.get_attributes_for_oid(CN)
common_name = name_attrs[0].value if name_attrs else ""
fingerprint_bytes = cert.fingerprint(hashes.SHA256())
fingerprint = base64.b64encode(fingerprint_bytes).decode()
not_before = cert.not_valid_before.strftime("%Y-%m-%dT%H:%M:%SZ")
not_after = cert.not_valid_after.strftime("%Y-%m-%dT%H:%M:%SZ")
environ.update(
{
"AUTH_TYPE": "CERTIFICATE",
"REMOTE_USER": subject.get("commonName", ""),
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
"REMOTE_USER": common_name,
"TLS_CLIENT_HASH": fingerprint,
"TLS_CLIENT_NOT_BEFORE": not_before,
"TLS_CLIENT_NOT_AFTER": not_after,
"TLS_CLIENT_SERIAL_NUMBER": cert.serial_number,
}
)
@ -611,12 +735,10 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
The request is a single UTF-8 line formatted as: <URL>\r\n
"""
data = self.rfile.readline(1026)
data = data.rstrip(b"\r\n")
if len(data) > 1024:
if len(self.request) > 1024:
raise ValueError("URL exceeds max length of 1024 bytes")
self.url = data.decode()
self.url = self.request.decode()
def write_status(self, status: int, meta: str) -> None:
"""
@ -643,7 +765,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
"""
self.flush_status()
self.response_size += len(data)
self.wfile.write(data)
self.transport.write(data)
def flush_status(self) -> None:
"""
@ -652,162 +774,113 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
if self.response_buffer and not self.response_size:
data = self.response_buffer.encode()
self.response_size += len(data)
self.wfile.write(data)
self.transport.write(data)
self.response_buffer = ""
def finish(self) -> None:
self.flush_status()
try:
self.log_request()
except AttributeError:
# Malformed request or dropped connection
pass
super().finish()
def log_request(self) -> None:
"""
Log a gemini request using a format derived from the Common Log Format.
"""
self.server.log_message(
f"{self.client_address[0]} "
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
f'"{self.url}" '
f"{self.status} "
f'"{self.meta}" '
f"{self.response_size}"
message = '{} [{}] "{}" {} {} {}'.format(
self.transport.getPeer().host,
time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp),
self.url,
self.status,
self.meta,
self.response_size,
)
self.server.log_message(message)
class GeminiServer(socketserver.ThreadingTCPServer):
class GeminiServer(Factory):
"""
An asynchronous TCP server that uses the asyncio stream abstraction.
This class acts as a wrapper around most of the plumbing for twisted.
This is a lightweight class that accepts incoming requests, logs them, and
sends them to a configurable request handler to be processed.
There's not much going on here, the main intention is to make it as simple
as possible to import and run a server without needing to understand the
complicated class hierarchy and conventions defined by twisted.
"""
request_handler_class = GeminiRequestHandler
# Initializes the pyOpenSSL context object, you may want to override this
# to customize your server's TLS configuration.
tls_context_factory_class = GeminiTLSContextFactory
# Request handler class, you probably don't want to override this.
protocol_class = GeminiProtocol
# The TLS twisted interface class is confusingly named SSL4, even though it
# will accept either IPv4 & IPv6 interfaces.
endpoint_class = SSL4ServerEndpoint
def __init__(
self,
app: typing.Callable,
reactor: ReactorBase = reactor,
host: str = "127.0.0.1",
port: int = 1965,
ssl_context: SSL.Context = None,
hostname: str = "localhost",
) -> None:
certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None,
**_,
):
self.app = app
self.reactor = reactor
self.host = host
self.port = port
self.hostname = hostname
self.ssl_context = ssl_context
super().__init__((host, port), self.request_handler_class, False)
self.certfile = certfile
self.keyfile = keyfile
self.cafile = cafile
self.capath = capath
def run(self) -> None:
def log_message(self, message: str) -> None:
"""
Launch the main server loop.
Log a diagnostic server message to stderr.
"""
self.log_message(ABOUT)
self.log_message(f"Server hostname is {self.hostname}")
try:
self.server_bind()
self.server_activate()
except Exception:
self.server_close()
raise
print(message, file=sys.stderr)
sock_ip, sock_port, *_ = self.server_address
if self.address_family == socket.AF_INET:
def on_bind_interface(self, port: Port) -> None:
"""
Log when the server binds to an interface.
"""
sock_ip, sock_port, *_ = port.socket.getsockname()
if port.addressFamily == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
self.serve_forever()
def get_request(self) -> typing.Tuple[SSL.Connection, typing.Tuple[str, int]]:
def buildProtocol(self, addr) -> GeminiProtocol:
"""
Wrap the incoming request in an SSL connection.
This method is invoked by twisted once for every incoming connection.
It builds the protocol instance which acts as a request handler and
implements the actual Gemini protocol.
"""
# noinspection PyTupleAssignmentBalance
sock, client_addr = super(GeminiServer, self).get_request()
sock = SSL.Connection(self.ssl_context, sock)
return sock, client_addr
return GeminiProtocol(self, self.app)
def log_message(self, message: str) -> None:
def run(self) -> None:
"""
Log a diagnostic server message to stderr, may be overridden.
This is the main server loop.
"""
print(message, file=sys.stderr)
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
"""
Utility function to generate an ad-hoc self-signed SSL certificate.
"""
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
if not os.path.exists(certfile) or not os.path.exists(keyfile):
backend = default_backend()
print("Generating private key...", file=sys.stderr)
private_key = rsa.generate_private_key(65537, 2048, default_backend())
with open(keyfile, "wb") as fp:
# noinspection PyTypeChecker
key_data = private_key.private_bytes(
serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
fp.write(key_data)
print("Generating certificate...", file=sys.stderr)
common_name = x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname)
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
serial_number=x509.random_serial_number(),
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
self.log_message(ABOUT)
self.log_message(f"Server hostname is {self.hostname}")
tls_context_factory = self.tls_context_factory_class(
hostname=self.hostname,
certfile=self.certfile,
keyfile=self.keyfile,
cafile=self.cafile,
capath=self.capath,
)
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
fp.write(cert_data)
return certfile, keyfile
def make_ssl_context(
hostname: str = "localhost",
certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None,
) -> SSL.Context:
"""
Generate a sane default SSL context for a Gemini server.
"""
if certfile is None:
certfile, keyfile = generate_ad_hoc_certificate(hostname)
context = SSL.Context(SSL.TLSv1_2_METHOD)
context.use_certificate_file(certfile)
context.use_privatekey_file(keyfile or certfile)
context.check_privatekey()
if cafile or capath:
context.load_verify_locations(cafile, capath)
def verify_cb(connection, x509, err_no, err_depth, return_code):
pass
context.set_verify(SSL.VERIFY_PEER, verify_cb)
return context
endpoint = self.endpoint_class(
reactor=self.reactor,
port=self.port,
sslContextFactory=tls_context_factory,
interface=self.host,
)
endpoint.listen(self).addCallback(self.on_bind_interface)
self.reactor.run()
def run_server() -> None:
@ -816,10 +889,7 @@ def run_server() -> None:
"""
args = parser.parse_args()
app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir)
ssl_context = make_ssl_context(
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
server = GeminiServer(app, **vars(args))
server.run()

View File

@ -16,7 +16,7 @@ setuptools.setup(
author="Michael Lazar",
author_email="lazar.michael22@gmail.com",
description="An Experimental Gemini Server",
install_requires=["cryptography", "pyopenssl"],
install_requires=["cryptography", "pyopenssl", "twisted"],
long_description=long_description(),
long_description_content_type="text/markdown",
py_modules=["jetforce", "jetforce_client", "jetforce_diagnostics"],