Move server framework from asyncio to socketserver

This commit is contained in:
Michael Lazar 2020-05-09 01:34:02 -04:00
parent c369c3b9fd
commit 7996b49792
1 changed files with 80 additions and 99 deletions

View File

@ -35,7 +35,6 @@ StaticDirectoryApplication:
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import asyncio
import codecs import codecs
import dataclasses import dataclasses
import datetime import datetime
@ -44,6 +43,7 @@ import os
import pathlib import pathlib
import re import re
import socket import socket
import socketserver
import ssl import ssl
import subprocess import subprocess
import sys import sys
@ -453,7 +453,7 @@ class StaticDirectoryApplication(JetforceApplication):
def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]: def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]:
""" """
Load a file using a generator to allow streaming data to the TCP socket. Load a file in chunks to allow streaming to the TCP socket.
""" """
with filesystem_path.open("rb") as fp: with filesystem_path.open("rb") as fp:
data = fp.read(1024) data = fp.read(1024)
@ -514,7 +514,7 @@ class StaticDirectoryApplication(JetforceApplication):
return Response(Status.NOT_FOUND, "Not Found") return Response(Status.NOT_FOUND, "Not Found")
class GeminiRequestHandler: class GeminiRequestHandler(socketserver.StreamRequestHandler):
""" """
Handle a single Gemini Protocol TCP request. Handle a single Gemini Protocol TCP request.
@ -534,55 +534,40 @@ class GeminiRequestHandler:
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
reader: asyncio.StreamReader server: GeminiServer
writer: asyncio.StreamWriter
received_timestamp: time.struct_time received_timestamp: time.struct_time
remote_addr: str
client_cert: dict
url: str url: str
status: int status: int
meta: str meta: str
response_buffer: str response_buffer: str
response_size: int response_size: int
def __init__(self, server: GeminiServer, app: typing.Callable) -> None: def setup(self) -> None:
self.server = server
self.app = app
self.response_size = 0 self.response_size = 0
self.response_buffer = ""
super().setup()
async def handle( def handle(self) -> None:
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
""" """
Main method for the request handler, performs the following: The request handler entry point, called once for each connection.
1. Read the request bytes from the reader stream
2. Parse the request and generate response data
3. Write the response bytes to the writer stream
""" """
self.reader = reader
self.writer = writer
self.remote_addr = writer.get_extra_info("peername")[0]
self.client_cert = writer.get_extra_info("peercert")
self.received_timestamp = time.localtime() self.received_timestamp = time.localtime()
self.request.do_handshake()
try: try:
await self.parse_header() self.parse_header()
except Exception: except Exception:
# Malformed request, throw it away and exit immediately # Malformed request, throw it away and exit immediately
self.write_status(Status.BAD_REQUEST, "Malformed request") self.write_status(Status.BAD_REQUEST, "Malformed request")
return await self.close_connection() return
try: try:
environ = self.build_environ() environ = self.build_environ()
app = self.app(environ, self.write_status) response_generator = self.server.app(environ, self.write_status)
for data in app: for data in response_generator:
await self.write_body(data) self.write_body(data)
except Exception: except Exception:
self.write_status(Status.CGI_ERROR, "An unexpected error occurred") self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
raise
finally:
await self.close_connection()
def build_environ(self) -> typing.Dict[str, typing.Any]: def build_environ(self) -> typing.Dict[str, typing.Any]:
""" """
@ -591,41 +576,43 @@ class GeminiRequestHandler:
Variable names conform to the CGI spec defined in RFC 3875. Variable names conform to the CGI spec defined in RFC 3875.
""" """
url_parts = urllib.parse.urlparse(self.url) url_parts = urllib.parse.urlparse(self.url)
remote_addr, remote_port, *_ = self.client_address
environ = { environ = {
"GEMINI_URL": self.url, "GEMINI_URL": self.url,
"HOSTNAME": self.server.hostname, "HOSTNAME": self.server.hostname,
"PATH_INFO": url_parts.path, "PATH_INFO": url_parts.path,
"QUERY_STRING": url_parts.query, "QUERY_STRING": url_parts.query,
"REMOTE_ADDR": self.remote_addr, "REMOTE_ADDR": remote_addr,
"REMOTE_HOST": self.remote_addr, "REMOTE_HOST": remote_addr,
"SERVER_NAME": self.server.hostname, "SERVER_NAME": self.server.hostname,
"SERVER_PORT": str(self.server.port), "SERVER_PORT": str(remote_port),
"SERVER_PROTOCOL": "GEMINI", "SERVER_PROTOCOL": "GEMINI",
"SERVER_SOFTWARE": f"jetforce/{__version__}", "SERVER_SOFTWARE": f"jetforce/{__version__}",
} }
if self.client_cert: client_cert = self.request.getpeercert()
subject = dict(x[0] for x in self.client_cert["subject"]) if client_cert:
subject = dict(x[0] for x in client_cert["subject"])
environ.update( environ.update(
{ {
"AUTH_TYPE": "CERTIFICATE", "AUTH_TYPE": "CERTIFICATE",
"REMOTE_USER": subject.get("commonName", ""), "REMOTE_USER": subject.get("commonName", ""),
"TLS_CLIENT_NOT_BEFORE": self.client_cert["notBefore"], "TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
"TLS_CLIENT_NOT_AFTER": self.client_cert["notAfter"], "TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
"TLS_CLIENT_SERIAL_NUMBER": self.client_cert["serialNumber"], "TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
} }
) )
return environ return environ
async def parse_header(self) -> None: def parse_header(self) -> None:
""" """
Parse the gemini header line. Parse the gemini header line.
The request is a single UTF-8 line formatted as: <URL>\r\n The request is a single UTF-8 line formatted as: <URL>\r\n
""" """
data = await self.reader.readuntil(b"\r\n") data = self.rfile.readline(1026)
data = data[:-2] # strip the line ending data = data.rstrip(b"\r\n")
if len(data) > 1024: if len(data) > 1024:
raise ValueError("URL exceeds max length of 1024 bytes") raise ValueError("URL exceeds max length of 1024 bytes")
@ -650,53 +637,48 @@ class GeminiRequestHandler:
self.meta = meta self.meta = meta
self.response_buffer = f"{status}\t{meta}\r\n" self.response_buffer = f"{status}\t{meta}\r\n"
async def write_body(self, data: bytes) -> None: def write_body(self, data: bytes) -> None:
""" """
Write bytes to the gemini response body. Write bytes to the gemini response body.
""" """
await self.flush_status() self.flush_status()
self.response_size += len(data) self.response_size += len(data)
self.writer.write(data) self.wfile.write(data)
await self.writer.drain()
async def flush_status(self) -> None: def flush_status(self) -> None:
""" """
Flush the status line from the internal buffer to the socket stream. Flush the status line from the internal buffer to the socket stream.
""" """
if self.response_buffer and not self.response_size: if self.response_buffer and not self.response_size:
data = self.response_buffer.encode() data = self.response_buffer.encode()
self.response_size += len(data) self.response_size += len(data)
self.writer.write(data) self.wfile.write(data)
await self.writer.drain()
self.response_buffer = "" self.response_buffer = ""
async def close_connection(self) -> None: def finish(self) -> None:
""" self.flush_status()
Flush any remaining bytes and close the stream. try:
""" self.log_request()
await self.flush_status() except AttributeError:
self.log_request() # Malformed request or dropped connection
await self.writer.drain() pass
super().finish()
def log_request(self) -> None: def log_request(self) -> None:
""" """
Log a gemini request using a format derived from the Common Log Format. Log a gemini request using a format derived from the Common Log Format.
""" """
try: self.server.log_message(
self.server.log_message( f"{self.client_address[0]} "
f"{self.remote_addr} " f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] " f'"{self.url}" '
f'"{self.url}" ' f"{self.status} "
f"{self.status} " f'"{self.meta}" '
f'"{self.meta}" ' f"{self.response_size}"
f"{self.response_size}" )
)
except AttributeError:
# Malformed request or dropped connection
pass
class GeminiServer: class GeminiServer(socketserver.ThreadingTCPServer):
""" """
An asynchronous TCP server that uses the asyncio stream abstraction. An asynchronous TCP server that uses the asyncio stream abstraction.
@ -715,47 +697,46 @@ class GeminiServer:
hostname: str = "localhost", hostname: str = "localhost",
) -> None: ) -> None:
self.host = host
self.port = port
self.hostname = hostname
self.app = app self.app = app
self.hostname = hostname
self.ssl_context = ssl_context self.ssl_context = ssl_context
super().__init__((host, port), self.request_handler_class, False)
async def run(self) -> None: def run(self) -> None:
""" """
The main asynchronous server loop. Launch the main server loop.
""" """
self.log_message(ABOUT) self.log_message(ABOUT)
server = await asyncio.start_server(
self.accept_connection, self.host, self.port, ssl=self.ssl_context
)
self.log_message(f"Server hostname is {self.hostname}") self.log_message(f"Server hostname is {self.hostname}")
for sock in server.sockets:
sock_ip, sock_port, *_ = sock.getsockname()
if sock.family == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
async with server:
await server.serve_forever()
async def accept_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""
Hook called by the socket server when a new connection is accepted.
"""
request_handler = self.request_handler_class(self, self.app)
try: try:
await request_handler.handle(reader, writer) self.server_bind()
finally: self.server_activate()
writer.close() except Exception:
self.server_close()
raise
sock_ip, sock_port, *_ = self.server_address
if self.address_family == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
self.serve_forever()
def get_request(self):
"""
Wrap the incoming request in an SSL connection.
"""
# noinspection PyTupleAssignmentBalance
sock, client_addr = super(GeminiServer, self).get_request()
ssl_sock = self.ssl_context.wrap_socket(
sock, server_side=True, do_handshake_on_connect=False
)
return ssl_sock, client_addr
def log_message(self, message: str) -> None: def log_message(self, message: str) -> None:
""" """
Log a diagnostic server message. Log a diagnostic server message to stderr, may be overridden.
""" """
print(message, file=sys.stderr) print(message, file=sys.stderr)
@ -852,7 +833,7 @@ def run_server() -> None:
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
) )
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname) server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
asyncio.run(server.run()) server.run()
if __name__ == "__main__": if __name__ == "__main__":