Move server framework from asyncio to socketserver

This commit is contained in:
Michael Lazar 2020-05-09 01:34:02 -04:00
parent c369c3b9fd
commit 7996b49792
1 changed files with 80 additions and 99 deletions

View File

@ -35,7 +35,6 @@ StaticDirectoryApplication:
from __future__ import annotations
import argparse
import asyncio
import codecs
import dataclasses
import datetime
@ -44,6 +43,7 @@ import os
import pathlib
import re
import socket
import socketserver
import ssl
import subprocess
import sys
@ -453,7 +453,7 @@ class StaticDirectoryApplication(JetforceApplication):
def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]:
"""
Load a file using a generator to allow streaming data to the TCP socket.
Load a file in chunks to allow streaming to the TCP socket.
"""
with filesystem_path.open("rb") as fp:
data = fp.read(1024)
@ -514,7 +514,7 @@ class StaticDirectoryApplication(JetforceApplication):
return Response(Status.NOT_FOUND, "Not Found")
class GeminiRequestHandler:
class GeminiRequestHandler(socketserver.StreamRequestHandler):
"""
Handle a single Gemini Protocol TCP request.
@ -534,55 +534,40 @@ class GeminiRequestHandler:
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
server: GeminiServer
received_timestamp: time.struct_time
remote_addr: str
client_cert: dict
url: str
status: int
meta: str
response_buffer: str
response_size: int
def __init__(self, server: GeminiServer, app: typing.Callable) -> None:
self.server = server
self.app = app
def setup(self) -> None:
self.response_size = 0
self.response_buffer = ""
super().setup()
async def handle(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
def handle(self) -> None:
"""
Main method for the request handler, performs the following:
1. Read the request bytes from the reader stream
2. Parse the request and generate response data
3. Write the response bytes to the writer stream
The request handler entry point, called once for each connection.
"""
self.reader = reader
self.writer = writer
self.remote_addr = writer.get_extra_info("peername")[0]
self.client_cert = writer.get_extra_info("peercert")
self.received_timestamp = time.localtime()
self.request.do_handshake()
try:
await self.parse_header()
self.parse_header()
except Exception:
# Malformed request, throw it away and exit immediately
self.write_status(Status.BAD_REQUEST, "Malformed request")
return await self.close_connection()
return
try:
environ = self.build_environ()
app = self.app(environ, self.write_status)
for data in app:
await self.write_body(data)
response_generator = self.server.app(environ, self.write_status)
for data in response_generator:
self.write_body(data)
except Exception:
self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
raise
finally:
await self.close_connection()
def build_environ(self) -> typing.Dict[str, typing.Any]:
"""
@ -591,41 +576,43 @@ class GeminiRequestHandler:
Variable names conform to the CGI spec defined in RFC 3875.
"""
url_parts = urllib.parse.urlparse(self.url)
remote_addr, remote_port, *_ = self.client_address
environ = {
"GEMINI_URL": self.url,
"HOSTNAME": self.server.hostname,
"PATH_INFO": url_parts.path,
"QUERY_STRING": url_parts.query,
"REMOTE_ADDR": self.remote_addr,
"REMOTE_HOST": self.remote_addr,
"REMOTE_ADDR": remote_addr,
"REMOTE_HOST": remote_addr,
"SERVER_NAME": self.server.hostname,
"SERVER_PORT": str(self.server.port),
"SERVER_PORT": str(remote_port),
"SERVER_PROTOCOL": "GEMINI",
"SERVER_SOFTWARE": f"jetforce/{__version__}",
}
if self.client_cert:
subject = dict(x[0] for x in self.client_cert["subject"])
client_cert = self.request.getpeercert()
if client_cert:
subject = dict(x[0] for x in client_cert["subject"])
environ.update(
{
"AUTH_TYPE": "CERTIFICATE",
"REMOTE_USER": subject.get("commonName", ""),
"TLS_CLIENT_NOT_BEFORE": self.client_cert["notBefore"],
"TLS_CLIENT_NOT_AFTER": self.client_cert["notAfter"],
"TLS_CLIENT_SERIAL_NUMBER": self.client_cert["serialNumber"],
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
}
)
return environ
async def parse_header(self) -> None:
def parse_header(self) -> None:
"""
Parse the gemini header line.
The request is a single UTF-8 line formatted as: <URL>\r\n
"""
data = await self.reader.readuntil(b"\r\n")
data = data[:-2] # strip the line ending
data = self.rfile.readline(1026)
data = data.rstrip(b"\r\n")
if len(data) > 1024:
raise ValueError("URL exceeds max length of 1024 bytes")
@ -650,53 +637,48 @@ class GeminiRequestHandler:
self.meta = meta
self.response_buffer = f"{status}\t{meta}\r\n"
async def write_body(self, data: bytes) -> None:
def write_body(self, data: bytes) -> None:
"""
Write bytes to the gemini response body.
"""
await self.flush_status()
self.flush_status()
self.response_size += len(data)
self.writer.write(data)
await self.writer.drain()
self.wfile.write(data)
async def flush_status(self) -> None:
def flush_status(self) -> None:
"""
Flush the status line from the internal buffer to the socket stream.
"""
if self.response_buffer and not self.response_size:
data = self.response_buffer.encode()
self.response_size += len(data)
self.writer.write(data)
await self.writer.drain()
self.wfile.write(data)
self.response_buffer = ""
async def close_connection(self) -> None:
"""
Flush any remaining bytes and close the stream.
"""
await self.flush_status()
self.log_request()
await self.writer.drain()
def finish(self) -> None:
self.flush_status()
try:
self.log_request()
except AttributeError:
# Malformed request or dropped connection
pass
super().finish()
def log_request(self) -> None:
"""
Log a gemini request using a format derived from the Common Log Format.
"""
try:
self.server.log_message(
f"{self.remote_addr} "
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
f'"{self.url}" '
f"{self.status} "
f'"{self.meta}" '
f"{self.response_size}"
)
except AttributeError:
# Malformed request or dropped connection
pass
self.server.log_message(
f"{self.client_address[0]} "
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
f'"{self.url}" '
f"{self.status} "
f'"{self.meta}" '
f"{self.response_size}"
)
class GeminiServer:
class GeminiServer(socketserver.ThreadingTCPServer):
"""
An asynchronous TCP server that uses the asyncio stream abstraction.
@ -715,47 +697,46 @@ class GeminiServer:
hostname: str = "localhost",
) -> None:
self.host = host
self.port = port
self.hostname = hostname
self.app = app
self.hostname = hostname
self.ssl_context = ssl_context
super().__init__((host, port), self.request_handler_class, False)
async def run(self) -> None:
def run(self) -> None:
"""
The main asynchronous server loop.
Launch the main server loop.
"""
self.log_message(ABOUT)
server = await asyncio.start_server(
self.accept_connection, self.host, self.port, ssl=self.ssl_context
)
self.log_message(f"Server hostname is {self.hostname}")
for sock in server.sockets:
sock_ip, sock_port, *_ = sock.getsockname()
if sock.family == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
async with server:
await server.serve_forever()
async def accept_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""
Hook called by the socket server when a new connection is accepted.
"""
request_handler = self.request_handler_class(self, self.app)
try:
await request_handler.handle(reader, writer)
finally:
writer.close()
self.server_bind()
self.server_activate()
except Exception:
self.server_close()
raise
sock_ip, sock_port, *_ = self.server_address
if self.address_family == socket.AF_INET:
self.log_message(f"Listening on {sock_ip}:{sock_port}")
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
self.serve_forever()
def get_request(self):
"""
Wrap the incoming request in an SSL connection.
"""
# noinspection PyTupleAssignmentBalance
sock, client_addr = super(GeminiServer, self).get_request()
ssl_sock = self.ssl_context.wrap_socket(
sock, server_side=True, do_handshake_on_connect=False
)
return ssl_sock, client_addr
def log_message(self, message: str) -> None:
"""
Log a diagnostic server message.
Log a diagnostic server message to stderr, may be overridden.
"""
print(message, file=sys.stderr)
@ -852,7 +833,7 @@ def run_server() -> None:
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
asyncio.run(server.run())
server.run()
if __name__ == "__main__":