Move server framework from socketserver to twisted :)
This commit is contained in:
parent
ef4023bd5c
commit
335d79ad54
374
jetforce.py
374
jetforce.py
|
@ -35,6 +35,7 @@ StaticDirectoryApplication:
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import codecs
|
||||
import dataclasses
|
||||
import datetime
|
||||
|
@ -43,7 +44,6 @@ import os
|
|||
import pathlib
|
||||
import re
|
||||
import socket
|
||||
import socketserver
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
@ -56,6 +56,16 @@ from cryptography.hazmat.backends import default_backend
|
|||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from OpenSSL import SSL
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.base import ReactorBase
|
||||
from twisted.internet.endpoints import SSL4ServerEndpoint
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet.ssl import CertificateOptions
|
||||
from twisted.internet.tcp import Port
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
|
||||
CN = x509.NameOID.COMMON_NAME
|
||||
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
sys.exit("Fatal Error: jetforce requires Python 3.7+")
|
||||
|
@ -264,6 +274,49 @@ class RoutePattern:
|
|||
return re.fullmatch(self.path, request_path)
|
||||
|
||||
|
||||
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
||||
"""
|
||||
Utility function to generate an ad-hoc self-signed SSL certificate.
|
||||
"""
|
||||
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
|
||||
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
|
||||
|
||||
if not os.path.exists(certfile) or not os.path.exists(keyfile):
|
||||
backend = default_backend()
|
||||
|
||||
print("Generating private key...", file=sys.stderr)
|
||||
private_key = rsa.generate_private_key(65537, 2048, default_backend())
|
||||
with open(keyfile, "wb") as fp:
|
||||
# noinspection PyTypeChecker
|
||||
key_data = private_key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
fp.write(key_data)
|
||||
|
||||
print("Generating certificate...", file=sys.stderr)
|
||||
common_name = x509.NameAttribute(CN, hostname)
|
||||
subject_name = x509.Name([common_name])
|
||||
not_valid_before = datetime.datetime.utcnow()
|
||||
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||
certificate = x509.CertificateBuilder(
|
||||
subject_name=subject_name,
|
||||
issuer_name=subject_name,
|
||||
public_key=private_key.public_key(),
|
||||
serial_number=x509.random_serial_number(),
|
||||
not_valid_before=not_valid_before,
|
||||
not_valid_after=not_valid_after,
|
||||
)
|
||||
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
|
||||
with open(certfile, "wb") as fp:
|
||||
# noinspection PyTypeChecker
|
||||
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
|
||||
fp.write(cert_data)
|
||||
|
||||
return certfile, keyfile
|
||||
|
||||
|
||||
class JetforceApplication:
|
||||
"""
|
||||
Base Jetforce application class with primitive URL routing.
|
||||
|
@ -514,7 +567,46 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
return Response(Status.NOT_FOUND, "Not Found")
|
||||
|
||||
|
||||
class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
||||
class GeminiTLSContextFactory:
|
||||
"""
|
||||
Generate a sane default SSL context for a Gemini server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str = "localhost",
|
||||
certfile: typing.Optional[str] = None,
|
||||
keyfile: typing.Optional[str] = None,
|
||||
cafile: typing.Optional[str] = None,
|
||||
capath: typing.Optional[str] = None,
|
||||
):
|
||||
if certfile is None:
|
||||
certfile, keyfile = generate_ad_hoc_certificate(hostname)
|
||||
|
||||
context = SSL.Context(SSL.TLSv1_2_METHOD)
|
||||
context.use_certificate_file(certfile)
|
||||
context.use_privatekey_file(keyfile or certfile)
|
||||
context.check_privatekey()
|
||||
if cafile or capath:
|
||||
context.load_verify_locations(cafile, capath)
|
||||
context.set_verify(SSL.VERIFY_PEER, self.verify_cb)
|
||||
self.context = context
|
||||
|
||||
def getContext(self) -> SSL.Context:
|
||||
"""
|
||||
Return the SSL context, this method must be implemented for twisted.
|
||||
"""
|
||||
return self.context
|
||||
|
||||
def verify_cb(self, connection, x509, err_no, err_depth, return_code):
|
||||
"""
|
||||
Disable all peer certificate validation at the openSSL level in order
|
||||
to allow self-signed client certificates.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
class GeminiProtocol(LineOnlyReceiver):
|
||||
"""
|
||||
Handle a single Gemini Protocol TCP request.
|
||||
|
||||
|
@ -534,26 +626,49 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
|
||||
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
|
||||
|
||||
server: GeminiServer
|
||||
received_timestamp: time.struct_time
|
||||
connected_timestamp: time.struct_time
|
||||
request: bytes
|
||||
url: str
|
||||
status: int
|
||||
meta: str
|
||||
response_buffer: str
|
||||
response_size: int
|
||||
|
||||
def setup(self) -> None:
|
||||
def __init__(self, server: GeminiServer, app: JetforceApplication):
|
||||
self.server = server
|
||||
self.app = app
|
||||
|
||||
def connectionMade(self):
|
||||
"""
|
||||
This is invoked by twisted after the connection is first established.
|
||||
"""
|
||||
self.connected_timestamp = time.localtime()
|
||||
self.response_size = 0
|
||||
self.response_buffer = ""
|
||||
super().setup()
|
||||
|
||||
def handle(self) -> None:
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
The request handler entry point, called once for each connection.
|
||||
"""
|
||||
self.received_timestamp = time.localtime()
|
||||
self.request.do_handshake()
|
||||
This method is invoked by LineOnlyReceiver for every incoming line.
|
||||
|
||||
Because Gemini requests are only ever a single line long, this will
|
||||
only be called once and we can use it to handle the lifetime of the
|
||||
connection without managing any state.
|
||||
"""
|
||||
self.request = line
|
||||
try:
|
||||
try:
|
||||
self.handle()
|
||||
finally:
|
||||
self.flush_status()
|
||||
try:
|
||||
self.log_request()
|
||||
except Exception:
|
||||
# Malformed request or dropped connection
|
||||
pass
|
||||
finally:
|
||||
self.transport.loseConnection()
|
||||
|
||||
def handle(self):
|
||||
try:
|
||||
self.parse_header()
|
||||
except Exception:
|
||||
|
@ -563,7 +678,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
|
||||
try:
|
||||
environ = self.build_environ()
|
||||
response_generator = self.server.app(environ, self.write_status)
|
||||
response_generator = self.app(environ, self.write_status)
|
||||
for data in response_generator:
|
||||
self.write_body(data)
|
||||
except Exception:
|
||||
|
@ -576,30 +691,39 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
Variable names conform to the CGI spec defined in RFC 3875.
|
||||
"""
|
||||
url_parts = urllib.parse.urlparse(self.url)
|
||||
remote_addr, remote_port, *_ = self.client_address
|
||||
client_addr = self.transport.getPeer()
|
||||
environ = {
|
||||
"GEMINI_URL": self.url,
|
||||
"HOSTNAME": self.server.hostname,
|
||||
"PATH_INFO": url_parts.path,
|
||||
"QUERY_STRING": url_parts.query,
|
||||
"REMOTE_ADDR": remote_addr,
|
||||
"REMOTE_HOST": remote_addr,
|
||||
"REMOTE_ADDR": client_addr.host,
|
||||
"REMOTE_HOST": client_addr.host,
|
||||
"SERVER_NAME": self.server.hostname,
|
||||
"SERVER_PORT": str(remote_port),
|
||||
"SERVER_PORT": str(client_addr.port),
|
||||
"SERVER_PROTOCOL": "GEMINI",
|
||||
"SERVER_SOFTWARE": f"jetforce/{__version__}",
|
||||
}
|
||||
|
||||
client_cert = self.request.getpeercert()
|
||||
if client_cert:
|
||||
subject = dict(x[0] for x in client_cert["subject"])
|
||||
openssl_cert = self.transport.getPeerCertificate()
|
||||
if openssl_cert:
|
||||
# Extract useful information from the client certificate. These
|
||||
# mostly follow the naming convention from GLV-1.12556
|
||||
cert = openssl_cert.to_cryptography()
|
||||
name_attrs = cert.subject.get_attributes_for_oid(CN)
|
||||
common_name = name_attrs[0].value if name_attrs else ""
|
||||
fingerprint_bytes = cert.fingerprint(hashes.SHA256())
|
||||
fingerprint = base64.b64encode(fingerprint_bytes).decode()
|
||||
not_before = cert.not_valid_before.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
not_after = cert.not_valid_after.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
environ.update(
|
||||
{
|
||||
"AUTH_TYPE": "CERTIFICATE",
|
||||
"REMOTE_USER": subject.get("commonName", ""),
|
||||
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
|
||||
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
|
||||
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
|
||||
"REMOTE_USER": common_name,
|
||||
"TLS_CLIENT_HASH": fingerprint,
|
||||
"TLS_CLIENT_NOT_BEFORE": not_before,
|
||||
"TLS_CLIENT_NOT_AFTER": not_after,
|
||||
"TLS_CLIENT_SERIAL_NUMBER": cert.serial_number,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -611,12 +735,10 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
|
||||
The request is a single UTF-8 line formatted as: <URL>\r\n
|
||||
"""
|
||||
data = self.rfile.readline(1026)
|
||||
data = data.rstrip(b"\r\n")
|
||||
if len(data) > 1024:
|
||||
if len(self.request) > 1024:
|
||||
raise ValueError("URL exceeds max length of 1024 bytes")
|
||||
|
||||
self.url = data.decode()
|
||||
self.url = self.request.decode()
|
||||
|
||||
def write_status(self, status: int, meta: str) -> None:
|
||||
"""
|
||||
|
@ -643,7 +765,7 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
"""
|
||||
self.flush_status()
|
||||
self.response_size += len(data)
|
||||
self.wfile.write(data)
|
||||
self.transport.write(data)
|
||||
|
||||
def flush_status(self) -> None:
|
||||
"""
|
||||
|
@ -652,162 +774,113 @@ class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
|||
if self.response_buffer and not self.response_size:
|
||||
data = self.response_buffer.encode()
|
||||
self.response_size += len(data)
|
||||
self.wfile.write(data)
|
||||
self.transport.write(data)
|
||||
self.response_buffer = ""
|
||||
|
||||
def finish(self) -> None:
|
||||
self.flush_status()
|
||||
try:
|
||||
self.log_request()
|
||||
except AttributeError:
|
||||
# Malformed request or dropped connection
|
||||
pass
|
||||
super().finish()
|
||||
|
||||
def log_request(self) -> None:
|
||||
"""
|
||||
Log a gemini request using a format derived from the Common Log Format.
|
||||
"""
|
||||
self.server.log_message(
|
||||
f"{self.client_address[0]} "
|
||||
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
|
||||
f'"{self.url}" '
|
||||
f"{self.status} "
|
||||
f'"{self.meta}" '
|
||||
f"{self.response_size}"
|
||||
message = '{} [{}] "{}" {} {} {}'.format(
|
||||
self.transport.getPeer().host,
|
||||
time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp),
|
||||
self.url,
|
||||
self.status,
|
||||
self.meta,
|
||||
self.response_size,
|
||||
)
|
||||
self.server.log_message(message)
|
||||
|
||||
|
||||
class GeminiServer(socketserver.ThreadingTCPServer):
|
||||
class GeminiServer(Factory):
|
||||
"""
|
||||
An asynchronous TCP server that uses the asyncio stream abstraction.
|
||||
This class acts as a wrapper around most of the plumbing for twisted.
|
||||
|
||||
This is a lightweight class that accepts incoming requests, logs them, and
|
||||
sends them to a configurable request handler to be processed.
|
||||
There's not much going on here, the main intention is to make it as simple
|
||||
as possible to import and run a server without needing to understand the
|
||||
complicated class hierarchy and conventions defined by twisted.
|
||||
"""
|
||||
|
||||
request_handler_class = GeminiRequestHandler
|
||||
# Initializes the pyOpenSSL context object, you may want to override this
|
||||
# to customize your server's TLS configuration.
|
||||
tls_context_factory_class = GeminiTLSContextFactory
|
||||
|
||||
# Request handler class, you probably don't want to override this.
|
||||
protocol_class = GeminiProtocol
|
||||
|
||||
# The TLS twisted interface class is confusingly named SSL4, even though it
|
||||
# will accept either IPv4 & IPv6 interfaces.
|
||||
endpoint_class = SSL4ServerEndpoint
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Callable,
|
||||
reactor: ReactorBase = reactor,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 1965,
|
||||
ssl_context: SSL.Context = None,
|
||||
hostname: str = "localhost",
|
||||
) -> None:
|
||||
|
||||
certfile: typing.Optional[str] = None,
|
||||
keyfile: typing.Optional[str] = None,
|
||||
cafile: typing.Optional[str] = None,
|
||||
capath: typing.Optional[str] = None,
|
||||
**_,
|
||||
):
|
||||
self.app = app
|
||||
self.reactor = reactor
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.hostname = hostname
|
||||
self.ssl_context = ssl_context
|
||||
super().__init__((host, port), self.request_handler_class, False)
|
||||
self.certfile = certfile
|
||||
self.keyfile = keyfile
|
||||
self.cafile = cafile
|
||||
self.capath = capath
|
||||
|
||||
def run(self) -> None:
|
||||
def log_message(self, message: str) -> None:
|
||||
"""
|
||||
Launch the main server loop.
|
||||
Log a diagnostic server message to stderr.
|
||||
"""
|
||||
self.log_message(ABOUT)
|
||||
self.log_message(f"Server hostname is {self.hostname}")
|
||||
try:
|
||||
self.server_bind()
|
||||
self.server_activate()
|
||||
except Exception:
|
||||
self.server_close()
|
||||
raise
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
sock_ip, sock_port, *_ = self.server_address
|
||||
if self.address_family == socket.AF_INET:
|
||||
def on_bind_interface(self, port: Port) -> None:
|
||||
"""
|
||||
Log when the server binds to an interface.
|
||||
"""
|
||||
sock_ip, sock_port, *_ = port.socket.getsockname()
|
||||
if port.addressFamily == socket.AF_INET:
|
||||
self.log_message(f"Listening on {sock_ip}:{sock_port}")
|
||||
else:
|
||||
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
||||
|
||||
self.serve_forever()
|
||||
|
||||
def get_request(self) -> typing.Tuple[SSL.Connection, typing.Tuple[str, int]]:
|
||||
def buildProtocol(self, addr) -> GeminiProtocol:
|
||||
"""
|
||||
Wrap the incoming request in an SSL connection.
|
||||
This method is invoked by twisted once for every incoming connection.
|
||||
|
||||
It builds the protocol instance which acts as a request handler and
|
||||
implements the actual Gemini protocol.
|
||||
"""
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
sock, client_addr = super(GeminiServer, self).get_request()
|
||||
sock = SSL.Connection(self.ssl_context, sock)
|
||||
return sock, client_addr
|
||||
return GeminiProtocol(self, self.app)
|
||||
|
||||
def log_message(self, message: str) -> None:
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Log a diagnostic server message to stderr, may be overridden.
|
||||
This is the main server loop.
|
||||
"""
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
|
||||
def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
||||
"""
|
||||
Utility function to generate an ad-hoc self-signed SSL certificate.
|
||||
"""
|
||||
certfile = os.path.join(tempfile.gettempdir(), f"{hostname}.crt")
|
||||
keyfile = os.path.join(tempfile.gettempdir(), f"{hostname}.key")
|
||||
|
||||
if not os.path.exists(certfile) or not os.path.exists(keyfile):
|
||||
backend = default_backend()
|
||||
|
||||
print("Generating private key...", file=sys.stderr)
|
||||
private_key = rsa.generate_private_key(65537, 2048, default_backend())
|
||||
with open(keyfile, "wb") as fp:
|
||||
# noinspection PyTypeChecker
|
||||
key_data = private_key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
fp.write(key_data)
|
||||
|
||||
print("Generating certificate...", file=sys.stderr)
|
||||
common_name = x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname)
|
||||
subject_name = x509.Name([common_name])
|
||||
not_valid_before = datetime.datetime.utcnow()
|
||||
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||
certificate = x509.CertificateBuilder(
|
||||
subject_name=subject_name,
|
||||
issuer_name=subject_name,
|
||||
public_key=private_key.public_key(),
|
||||
serial_number=x509.random_serial_number(),
|
||||
not_valid_before=not_valid_before,
|
||||
not_valid_after=not_valid_after,
|
||||
self.log_message(ABOUT)
|
||||
self.log_message(f"Server hostname is {self.hostname}")
|
||||
tls_context_factory = self.tls_context_factory_class(
|
||||
hostname=self.hostname,
|
||||
certfile=self.certfile,
|
||||
keyfile=self.keyfile,
|
||||
cafile=self.cafile,
|
||||
capath=self.capath,
|
||||
)
|
||||
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
|
||||
with open(certfile, "wb") as fp:
|
||||
# noinspection PyTypeChecker
|
||||
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
|
||||
fp.write(cert_data)
|
||||
|
||||
return certfile, keyfile
|
||||
|
||||
|
||||
def make_ssl_context(
|
||||
hostname: str = "localhost",
|
||||
certfile: typing.Optional[str] = None,
|
||||
keyfile: typing.Optional[str] = None,
|
||||
cafile: typing.Optional[str] = None,
|
||||
capath: typing.Optional[str] = None,
|
||||
) -> SSL.Context:
|
||||
"""
|
||||
Generate a sane default SSL context for a Gemini server.
|
||||
"""
|
||||
if certfile is None:
|
||||
certfile, keyfile = generate_ad_hoc_certificate(hostname)
|
||||
|
||||
context = SSL.Context(SSL.TLSv1_2_METHOD)
|
||||
context.use_certificate_file(certfile)
|
||||
context.use_privatekey_file(keyfile or certfile)
|
||||
context.check_privatekey()
|
||||
if cafile or capath:
|
||||
context.load_verify_locations(cafile, capath)
|
||||
|
||||
def verify_cb(connection, x509, err_no, err_depth, return_code):
|
||||
pass
|
||||
|
||||
context.set_verify(SSL.VERIFY_PEER, verify_cb)
|
||||
|
||||
return context
|
||||
endpoint = self.endpoint_class(
|
||||
reactor=self.reactor,
|
||||
port=self.port,
|
||||
sslContextFactory=tls_context_factory,
|
||||
interface=self.host,
|
||||
)
|
||||
endpoint.listen(self).addCallback(self.on_bind_interface)
|
||||
self.reactor.run()
|
||||
|
||||
|
||||
def run_server() -> None:
|
||||
|
@ -816,10 +889,7 @@ def run_server() -> None:
|
|||
"""
|
||||
args = parser.parse_args()
|
||||
app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir)
|
||||
ssl_context = make_ssl_context(
|
||||
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
|
||||
)
|
||||
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
|
||||
server = GeminiServer(app, **vars(args))
|
||||
server.run()
|
||||
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -16,7 +16,7 @@ setuptools.setup(
|
|||
author="Michael Lazar",
|
||||
author_email="lazar.michael22@gmail.com",
|
||||
description="An Experimental Gemini Server",
|
||||
install_requires=["cryptography", "pyopenssl"],
|
||||
install_requires=["cryptography", "pyopenssl", "twisted"],
|
||||
long_description=long_description(),
|
||||
long_description_content_type="text/markdown",
|
||||
py_modules=["jetforce", "jetforce_client", "jetforce_diagnostics"],
|
||||
|
|
Loading…
Reference in New Issue