From 7996b49792b6cba15d17304578ebfe6f03b059da Mon Sep 17 00:00:00 2001 From: Michael Lazar Date: Sat, 9 May 2020 01:34:02 -0400 Subject: [PATCH] Move server framework from asyncio to socketserver --- jetforce.py | 179 +++++++++++++++++++++++----------------------------- 1 file changed, 80 insertions(+), 99 deletions(-) diff --git a/jetforce.py b/jetforce.py index ee3c947..ba1250e 100755 --- a/jetforce.py +++ b/jetforce.py @@ -35,7 +35,6 @@ StaticDirectoryApplication: from __future__ import annotations import argparse -import asyncio import codecs import dataclasses import datetime @@ -44,6 +43,7 @@ import os import pathlib import re import socket +import socketserver import ssl import subprocess import sys @@ -453,7 +453,7 @@ class StaticDirectoryApplication(JetforceApplication): def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]: """ - Load a file using a generator to allow streaming data to the TCP socket. + Load a file in chunks to allow streaming to the TCP socket. """ with filesystem_path.open("rb") as fp: data = fp.read(1024) @@ -514,7 +514,7 @@ class StaticDirectoryApplication(JetforceApplication): return Response(Status.NOT_FOUND, "Not Found") -class GeminiRequestHandler: +class GeminiRequestHandler(socketserver.StreamRequestHandler): """ Handle a single Gemini Protocol TCP request. 
@@ -534,55 +534,40 @@ class GeminiRequestHandler: TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" - reader: asyncio.StreamReader - writer: asyncio.StreamWriter + server: GeminiServer received_timestamp: time.struct_time - remote_addr: str - client_cert: dict url: str status: int meta: str response_buffer: str response_size: int - def __init__(self, server: GeminiServer, app: typing.Callable) -> None: - self.server = server - self.app = app + def setup(self) -> None: self.response_size = 0 + self.response_buffer = "" + super().setup() - async def handle( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: + def handle(self) -> None: """ - Main method for the request handler, performs the following: - - 1. Read the request bytes from the reader stream - 2. Parse the request and generate response data - 3. Write the response bytes to the writer stream + The request handler entry point, called once for each connection. """ - self.reader = reader - self.writer = writer - self.remote_addr = writer.get_extra_info("peername")[0] - self.client_cert = writer.get_extra_info("peercert") self.received_timestamp = time.localtime() + self.request.do_handshake() try: - await self.parse_header() + self.parse_header() except Exception: # Malformed request, throw it away and exit immediately self.write_status(Status.BAD_REQUEST, "Malformed request") - return await self.close_connection() + return try: environ = self.build_environ() - app = self.app(environ, self.write_status) - for data in app: - await self.write_body(data) + response_generator = self.server.app(environ, self.write_status) + for data in response_generator: + self.write_body(data) except Exception: self.write_status(Status.CGI_ERROR, "An unexpected error occurred") - raise - finally: - await self.close_connection() def build_environ(self) -> typing.Dict[str, typing.Any]: """ @@ -591,41 +576,43 @@ class GeminiRequestHandler: Variable names conform to the CGI spec defined in RFC 3875. 
""" url_parts = urllib.parse.urlparse(self.url) + remote_addr, *_ = self.client_address environ = { "GEMINI_URL": self.url, "HOSTNAME": self.server.hostname, "PATH_INFO": url_parts.path, "QUERY_STRING": url_parts.query, - "REMOTE_ADDR": self.remote_addr, - "REMOTE_HOST": self.remote_addr, + "REMOTE_ADDR": remote_addr, + "REMOTE_HOST": remote_addr, "SERVER_NAME": self.server.hostname, - "SERVER_PORT": str(self.server.port), + "SERVER_PORT": str(self.server.server_address[1]), "SERVER_PROTOCOL": "GEMINI", "SERVER_SOFTWARE": f"jetforce/{__version__}", } - if self.client_cert: - subject = dict(x[0] for x in self.client_cert["subject"]) + client_cert = self.request.getpeercert() + if client_cert: + subject = dict(x[0] for x in client_cert["subject"]) environ.update( { "AUTH_TYPE": "CERTIFICATE", "REMOTE_USER": subject.get("commonName", ""), - "TLS_CLIENT_NOT_BEFORE": self.client_cert["notBefore"], - "TLS_CLIENT_NOT_AFTER": self.client_cert["notAfter"], - "TLS_CLIENT_SERIAL_NUMBER": self.client_cert["serialNumber"], + "TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"], + "TLS_CLIENT_NOT_AFTER": client_cert["notAfter"], + "TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"], } ) return environ - async def parse_header(self) -> None: + def parse_header(self) -> None: """ Parse the gemini header line. The request is a single UTF-8 line formatted as: \r\n """ - data = await self.reader.readuntil(b"\r\n") - data = data[:-2] # strip the line ending + data = self.rfile.readline(1026) + data = data.rstrip(b"\r\n") if len(data) > 1024: raise ValueError("URL exceeds max length of 1024 bytes") @@ -650,53 +637,48 @@ class GeminiRequestHandler: self.meta = meta self.response_buffer = f"{status}\t{meta}\r\n" - async def write_body(self, data: bytes) -> None: + def write_body(self, data: bytes) -> None: """ Write bytes to the gemini response body. 
""" - await self.flush_status() + self.flush_status() self.response_size += len(data) - self.writer.write(data) - await self.writer.drain() + self.wfile.write(data) - async def flush_status(self) -> None: + def flush_status(self) -> None: """ Flush the status line from the internal buffer to the socket stream. """ if self.response_buffer and not self.response_size: data = self.response_buffer.encode() self.response_size += len(data) - self.writer.write(data) - await self.writer.drain() + self.wfile.write(data) self.response_buffer = "" - async def close_connection(self) -> None: - """ - Flush any remaining bytes and close the stream. - """ - await self.flush_status() - self.log_request() - await self.writer.drain() + def finish(self) -> None: + self.flush_status() + try: + self.log_request() + except AttributeError: + # Malformed request or dropped connection + pass + super().finish() def log_request(self) -> None: """ Log a gemini request using a format derived from the Common Log Format. """ - try: - self.server.log_message( - f"{self.remote_addr} " - f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] " - f'"{self.url}" ' - f"{self.status} " - f'"{self.meta}" ' - f"{self.response_size}" - ) - except AttributeError: - # Malformed request or dropped connection - pass + self.server.log_message( + f"{self.client_address[0]} " + f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] " + f'"{self.url}" ' + f"{self.status} " + f'"{self.meta}" ' + f"{self.response_size}" + ) -class GeminiServer: +class GeminiServer(socketserver.ThreadingTCPServer): """ An asynchronous TCP server that uses the asyncio stream abstraction. 
@@ -715,47 +697,46 @@ class GeminiServer: hostname: str = "localhost", ) -> None: - self.host = host - self.port = port - self.hostname = hostname self.app = app + self.hostname = hostname self.ssl_context = ssl_context + super().__init__((host, port), self.request_handler_class, False) - async def run(self) -> None: + def run(self) -> None: """ - The main asynchronous server loop. + Launch the main server loop. """ self.log_message(ABOUT) - server = await asyncio.start_server( - self.accept_connection, self.host, self.port, ssl=self.ssl_context - ) - self.log_message(f"Server hostname is {self.hostname}") - for sock in server.sockets: - sock_ip, sock_port, *_ = sock.getsockname() - if sock.family == socket.AF_INET: - self.log_message(f"Listening on {sock_ip}:{sock_port}") - else: - self.log_message(f"Listening on [{sock_ip}]:{sock_port}") - - async with server: - await server.serve_forever() - - async def accept_connection( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - Hook called by the socket server when a new connection is accepted. - """ - request_handler = self.request_handler_class(self, self.app) try: - await request_handler.handle(reader, writer) - finally: - writer.close() + self.server_bind() + self.server_activate() + except Exception: + self.server_close() + raise + + sock_ip, sock_port, *_ = self.server_address + if self.address_family == socket.AF_INET: + self.log_message(f"Listening on {sock_ip}:{sock_port}") + else: + self.log_message(f"Listening on [{sock_ip}]:{sock_port}") + + self.serve_forever() + + def get_request(self): + """ + Wrap the incoming request in an SSL connection. + """ + # noinspection PyTupleAssignmentBalance + sock, client_addr = super(GeminiServer, self).get_request() + ssl_sock = self.ssl_context.wrap_socket( + sock, server_side=True, do_handshake_on_connect=False + ) + return ssl_sock, client_addr def log_message(self, message: str) -> None: """ - Log a diagnostic server message. 
+ Log a diagnostic server message to stderr, may be overridden. """ print(message, file=sys.stderr) @@ -852,7 +833,7 @@ def run_server() -> None: args.hostname, args.certfile, args.keyfile, args.cafile, args.capath ) server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname) - asyncio.run(server.run()) + server.run() if __name__ == "__main__":