Move server framework from asyncio to socketserver
This commit is contained in:
parent
c369c3b9fd
commit
7996b49792
151
jetforce.py
151
jetforce.py
|
@ -35,7 +35,6 @@ StaticDirectoryApplication:
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import codecs
|
||||
import dataclasses
|
||||
import datetime
|
||||
|
@ -44,6 +43,7 @@ import os
|
|||
import pathlib
|
||||
import re
|
||||
import socket
|
||||
import socketserver
|
||||
import ssl
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -453,7 +453,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
|
||||
def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Load a file using a generator to allow streaming data to the TCP socket.
|
||||
Load a file in chunks to allow streaming to the TCP socket.
|
||||
"""
|
||||
with filesystem_path.open("rb") as fp:
|
||||
data = fp.read(1024)
|
||||
|
@ -514,7 +514,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
return Response(Status.NOT_FOUND, "Not Found")
|
||||
|
||||
|
||||
class GeminiRequestHandler:
|
||||
class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
||||
"""
|
||||
Handle a single Gemini Protocol TCP request.
|
||||
|
||||
|
@ -534,55 +534,40 @@ class GeminiRequestHandler:
|
|||
|
||||
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
|
||||
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
server: GeminiServer
|
||||
received_timestamp: time.struct_time
|
||||
remote_addr: str
|
||||
client_cert: dict
|
||||
url: str
|
||||
status: int
|
||||
meta: str
|
||||
response_buffer: str
|
||||
response_size: int
|
||||
|
||||
def __init__(self, server: GeminiServer, app: typing.Callable) -> None:
|
||||
self.server = server
|
||||
self.app = app
|
||||
def setup(self) -> None:
|
||||
self.response_size = 0
|
||||
self.response_buffer = ""
|
||||
super().setup()
|
||||
|
||||
async def handle(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
def handle(self) -> None:
|
||||
"""
|
||||
Main method for the request handler, performs the following:
|
||||
|
||||
1. Read the request bytes from the reader stream
|
||||
2. Parse the request and generate response data
|
||||
3. Write the response bytes to the writer stream
|
||||
The request handler entry point, called once for each connection.
|
||||
"""
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.remote_addr = writer.get_extra_info("peername")[0]
|
||||
self.client_cert = writer.get_extra_info("peercert")
|
||||
self.received_timestamp = time.localtime()
|
||||
self.request.do_handshake()
|
||||
|
||||
try:
|
||||
await self.parse_header()
|
||||
self.parse_header()
|
||||
except Exception:
|
||||
# Malformed request, throw it away and exit immediately
|
||||
self.write_status(Status.BAD_REQUEST, "Malformed request")
|
||||
return await self.close_connection()
|
||||
return
|
||||
|
||||
try:
|
||||
environ = self.build_environ()
|
||||
app = self.app(environ, self.write_status)
|
||||
for data in app:
|
||||
await self.write_body(data)
|
||||
response_generator = self.server.app(environ, self.write_status)
|
||||
for data in response_generator:
|
||||
self.write_body(data)
|
||||
except Exception:
|
||||
self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
|
||||
raise
|
||||
finally:
|
||||
await self.close_connection()
|
||||
|
||||
def build_environ(self) -> typing.Dict[str, typing.Any]:
|
||||
"""
|
||||
|
@ -591,41 +576,43 @@ class GeminiRequestHandler:
|
|||
Variable names conform to the CGI spec defined in RFC 3875.
|
||||
"""
|
||||
url_parts = urllib.parse.urlparse(self.url)
|
||||
remote_addr, remote_port, *_ = self.client_address
|
||||
environ = {
|
||||
"GEMINI_URL": self.url,
|
||||
"HOSTNAME": self.server.hostname,
|
||||
"PATH_INFO": url_parts.path,
|
||||
"QUERY_STRING": url_parts.query,
|
||||
"REMOTE_ADDR": self.remote_addr,
|
||||
"REMOTE_HOST": self.remote_addr,
|
||||
"REMOTE_ADDR": remote_addr,
|
||||
"REMOTE_HOST": remote_addr,
|
||||
"SERVER_NAME": self.server.hostname,
|
||||
"SERVER_PORT": str(self.server.port),
|
||||
"SERVER_PORT": str(remote_port),
|
||||
"SERVER_PROTOCOL": "GEMINI",
|
||||
"SERVER_SOFTWARE": f"jetforce/{__version__}",
|
||||
}
|
||||
|
||||
if self.client_cert:
|
||||
subject = dict(x[0] for x in self.client_cert["subject"])
|
||||
client_cert = self.request.getpeercert()
|
||||
if client_cert:
|
||||
subject = dict(x[0] for x in client_cert["subject"])
|
||||
environ.update(
|
||||
{
|
||||
"AUTH_TYPE": "CERTIFICATE",
|
||||
"REMOTE_USER": subject.get("commonName", ""),
|
||||
"TLS_CLIENT_NOT_BEFORE": self.client_cert["notBefore"],
|
||||
"TLS_CLIENT_NOT_AFTER": self.client_cert["notAfter"],
|
||||
"TLS_CLIENT_SERIAL_NUMBER": self.client_cert["serialNumber"],
|
||||
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
|
||||
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
|
||||
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
|
||||
}
|
||||
)
|
||||
|
||||
return environ
|
||||
|
||||
async def parse_header(self) -> None:
|
||||
def parse_header(self) -> None:
|
||||
"""
|
||||
Parse the gemini header line.
|
||||
|
||||
The request is a single UTF-8 line formatted as: <URL>\r\n
|
||||
"""
|
||||
data = await self.reader.readuntil(b"\r\n")
|
||||
data = data[:-2] # strip the line ending
|
||||
data = self.rfile.readline(1026)
|
||||
data = data.rstrip(b"\r\n")
|
||||
if len(data) > 1024:
|
||||
raise ValueError("URL exceeds max length of 1024 bytes")
|
||||
|
||||
|
@ -650,53 +637,48 @@ class GeminiRequestHandler:
|
|||
self.meta = meta
|
||||
self.response_buffer = f"{status}\t{meta}\r\n"
|
||||
|
||||
async def write_body(self, data: bytes) -> None:
|
||||
def write_body(self, data: bytes) -> None:
|
||||
"""
|
||||
Write bytes to the gemini response body.
|
||||
"""
|
||||
await self.flush_status()
|
||||
self.flush_status()
|
||||
self.response_size += len(data)
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
self.wfile.write(data)
|
||||
|
||||
async def flush_status(self) -> None:
|
||||
def flush_status(self) -> None:
|
||||
"""
|
||||
Flush the status line from the internal buffer to the socket stream.
|
||||
"""
|
||||
if self.response_buffer and not self.response_size:
|
||||
data = self.response_buffer.encode()
|
||||
self.response_size += len(data)
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
self.wfile.write(data)
|
||||
self.response_buffer = ""
|
||||
|
||||
async def close_connection(self) -> None:
|
||||
"""
|
||||
Flush any remaining bytes and close the stream.
|
||||
"""
|
||||
await self.flush_status()
|
||||
def finish(self) -> None:
|
||||
self.flush_status()
|
||||
try:
|
||||
self.log_request()
|
||||
await self.writer.drain()
|
||||
except AttributeError:
|
||||
# Malformed request or dropped connection
|
||||
pass
|
||||
super().finish()
|
||||
|
||||
def log_request(self) -> None:
|
||||
"""
|
||||
Log a gemini request using a format derived from the Common Log Format.
|
||||
"""
|
||||
try:
|
||||
self.server.log_message(
|
||||
f"{self.remote_addr} "
|
||||
f"{self.client_address[0]} "
|
||||
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
|
||||
f'"{self.url}" '
|
||||
f"{self.status} "
|
||||
f'"{self.meta}" '
|
||||
f"{self.response_size}"
|
||||
)
|
||||
except AttributeError:
|
||||
# Malformed request or dropped connection
|
||||
pass
|
||||
|
||||
|
||||
class GeminiServer:
|
||||
class GeminiServer(socketserver.ThreadingTCPServer):
|
||||
"""
|
||||
An asynchronous TCP server that uses the asyncio stream abstraction.
|
||||
|
||||
|
@ -715,47 +697,46 @@ class GeminiServer:
|
|||
hostname: str = "localhost",
|
||||
) -> None:
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.hostname = hostname
|
||||
self.app = app
|
||||
self.hostname = hostname
|
||||
self.ssl_context = ssl_context
|
||||
super().__init__((host, port), self.request_handler_class, False)
|
||||
|
||||
async def run(self) -> None:
|
||||
def run(self) -> None:
|
||||
"""
|
||||
The main asynchronous server loop.
|
||||
Launch the main server loop.
|
||||
"""
|
||||
self.log_message(ABOUT)
|
||||
server = await asyncio.start_server(
|
||||
self.accept_connection, self.host, self.port, ssl=self.ssl_context
|
||||
)
|
||||
|
||||
self.log_message(f"Server hostname is {self.hostname}")
|
||||
for sock in server.sockets:
|
||||
sock_ip, sock_port, *_ = sock.getsockname()
|
||||
if sock.family == socket.AF_INET:
|
||||
try:
|
||||
self.server_bind()
|
||||
self.server_activate()
|
||||
except Exception:
|
||||
self.server_close()
|
||||
raise
|
||||
|
||||
sock_ip, sock_port, *_ = self.server_address
|
||||
if self.address_family == socket.AF_INET:
|
||||
self.log_message(f"Listening on {sock_ip}:{sock_port}")
|
||||
else:
|
||||
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
||||
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
self.serve_forever()
|
||||
|
||||
async def accept_connection(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
def get_request(self):
|
||||
"""
|
||||
Hook called by the socket server when a new connection is accepted.
|
||||
Wrap the incoming request in an SSL connection.
|
||||
"""
|
||||
request_handler = self.request_handler_class(self, self.app)
|
||||
try:
|
||||
await request_handler.handle(reader, writer)
|
||||
finally:
|
||||
writer.close()
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
sock, client_addr = super(GeminiServer, self).get_request()
|
||||
ssl_sock = self.ssl_context.wrap_socket(
|
||||
sock, server_side=True, do_handshake_on_connect=False
|
||||
)
|
||||
return ssl_sock, client_addr
|
||||
|
||||
def log_message(self, message: str) -> None:
|
||||
"""
|
||||
Log a diagnostic server message.
|
||||
Log a diagnostic server message to stderr, may be overridden.
|
||||
"""
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
|
@ -852,7 +833,7 @@ def run_server() -> None:
|
|||
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
|
||||
)
|
||||
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
|
||||
asyncio.run(server.run())
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue