Move server framework from asyncio to socketserver
This commit is contained in:
parent
c369c3b9fd
commit
7996b49792
151
jetforce.py
151
jetforce.py
|
@ -35,7 +35,6 @@ StaticDirectoryApplication:
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
import codecs
|
import codecs
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -44,6 +43,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
|
import socketserver
|
||||||
import ssl
|
import ssl
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
@ -453,7 +453,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
|
|
||||||
def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]:
|
def load_file(self, filesystem_path: pathlib.Path) -> typing.Iterator[bytes]:
|
||||||
"""
|
"""
|
||||||
Load a file using a generator to allow streaming data to the TCP socket.
|
Load a file in chunks to allow streaming to the TCP socket.
|
||||||
"""
|
"""
|
||||||
with filesystem_path.open("rb") as fp:
|
with filesystem_path.open("rb") as fp:
|
||||||
data = fp.read(1024)
|
data = fp.read(1024)
|
||||||
|
@ -514,7 +514,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
return Response(Status.NOT_FOUND, "Not Found")
|
return Response(Status.NOT_FOUND, "Not Found")
|
||||||
|
|
||||||
|
|
||||||
class GeminiRequestHandler:
|
class GeminiRequestHandler(socketserver.StreamRequestHandler):
|
||||||
"""
|
"""
|
||||||
Handle a single Gemini Protocol TCP request.
|
Handle a single Gemini Protocol TCP request.
|
||||||
|
|
||||||
|
@ -534,55 +534,40 @@ class GeminiRequestHandler:
|
||||||
|
|
||||||
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
|
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
|
||||||
|
|
||||||
reader: asyncio.StreamReader
|
server: GeminiServer
|
||||||
writer: asyncio.StreamWriter
|
|
||||||
received_timestamp: time.struct_time
|
received_timestamp: time.struct_time
|
||||||
remote_addr: str
|
|
||||||
client_cert: dict
|
|
||||||
url: str
|
url: str
|
||||||
status: int
|
status: int
|
||||||
meta: str
|
meta: str
|
||||||
response_buffer: str
|
response_buffer: str
|
||||||
response_size: int
|
response_size: int
|
||||||
|
|
||||||
def __init__(self, server: GeminiServer, app: typing.Callable) -> None:
|
def setup(self) -> None:
|
||||||
self.server = server
|
|
||||||
self.app = app
|
|
||||||
self.response_size = 0
|
self.response_size = 0
|
||||||
|
self.response_buffer = ""
|
||||||
|
super().setup()
|
||||||
|
|
||||||
async def handle(
|
def handle(self) -> None:
|
||||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Main method for the request handler, performs the following:
|
The request handler entry point, called once for each connection.
|
||||||
|
|
||||||
1. Read the request bytes from the reader stream
|
|
||||||
2. Parse the request and generate response data
|
|
||||||
3. Write the response bytes to the writer stream
|
|
||||||
"""
|
"""
|
||||||
self.reader = reader
|
|
||||||
self.writer = writer
|
|
||||||
self.remote_addr = writer.get_extra_info("peername")[0]
|
|
||||||
self.client_cert = writer.get_extra_info("peercert")
|
|
||||||
self.received_timestamp = time.localtime()
|
self.received_timestamp = time.localtime()
|
||||||
|
self.request.do_handshake()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.parse_header()
|
self.parse_header()
|
||||||
except Exception:
|
except Exception:
|
||||||
# Malformed request, throw it away and exit immediately
|
# Malformed request, throw it away and exit immediately
|
||||||
self.write_status(Status.BAD_REQUEST, "Malformed request")
|
self.write_status(Status.BAD_REQUEST, "Malformed request")
|
||||||
return await self.close_connection()
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
environ = self.build_environ()
|
environ = self.build_environ()
|
||||||
app = self.app(environ, self.write_status)
|
response_generator = self.server.app(environ, self.write_status)
|
||||||
for data in app:
|
for data in response_generator:
|
||||||
await self.write_body(data)
|
self.write_body(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
|
self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
await self.close_connection()
|
|
||||||
|
|
||||||
def build_environ(self) -> typing.Dict[str, typing.Any]:
|
def build_environ(self) -> typing.Dict[str, typing.Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -591,41 +576,43 @@ class GeminiRequestHandler:
|
||||||
Variable names conform to the CGI spec defined in RFC 3875.
|
Variable names conform to the CGI spec defined in RFC 3875.
|
||||||
"""
|
"""
|
||||||
url_parts = urllib.parse.urlparse(self.url)
|
url_parts = urllib.parse.urlparse(self.url)
|
||||||
|
remote_addr, remote_port, *_ = self.client_address
|
||||||
environ = {
|
environ = {
|
||||||
"GEMINI_URL": self.url,
|
"GEMINI_URL": self.url,
|
||||||
"HOSTNAME": self.server.hostname,
|
"HOSTNAME": self.server.hostname,
|
||||||
"PATH_INFO": url_parts.path,
|
"PATH_INFO": url_parts.path,
|
||||||
"QUERY_STRING": url_parts.query,
|
"QUERY_STRING": url_parts.query,
|
||||||
"REMOTE_ADDR": self.remote_addr,
|
"REMOTE_ADDR": remote_addr,
|
||||||
"REMOTE_HOST": self.remote_addr,
|
"REMOTE_HOST": remote_addr,
|
||||||
"SERVER_NAME": self.server.hostname,
|
"SERVER_NAME": self.server.hostname,
|
||||||
"SERVER_PORT": str(self.server.port),
|
"SERVER_PORT": str(remote_port),
|
||||||
"SERVER_PROTOCOL": "GEMINI",
|
"SERVER_PROTOCOL": "GEMINI",
|
||||||
"SERVER_SOFTWARE": f"jetforce/{__version__}",
|
"SERVER_SOFTWARE": f"jetforce/{__version__}",
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.client_cert:
|
client_cert = self.request.getpeercert()
|
||||||
subject = dict(x[0] for x in self.client_cert["subject"])
|
if client_cert:
|
||||||
|
subject = dict(x[0] for x in client_cert["subject"])
|
||||||
environ.update(
|
environ.update(
|
||||||
{
|
{
|
||||||
"AUTH_TYPE": "CERTIFICATE",
|
"AUTH_TYPE": "CERTIFICATE",
|
||||||
"REMOTE_USER": subject.get("commonName", ""),
|
"REMOTE_USER": subject.get("commonName", ""),
|
||||||
"TLS_CLIENT_NOT_BEFORE": self.client_cert["notBefore"],
|
"TLS_CLIENT_NOT_BEFORE": client_cert["notBefore"],
|
||||||
"TLS_CLIENT_NOT_AFTER": self.client_cert["notAfter"],
|
"TLS_CLIENT_NOT_AFTER": client_cert["notAfter"],
|
||||||
"TLS_CLIENT_SERIAL_NUMBER": self.client_cert["serialNumber"],
|
"TLS_CLIENT_SERIAL_NUMBER": client_cert["serialNumber"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return environ
|
return environ
|
||||||
|
|
||||||
async def parse_header(self) -> None:
|
def parse_header(self) -> None:
|
||||||
"""
|
"""
|
||||||
Parse the gemini header line.
|
Parse the gemini header line.
|
||||||
|
|
||||||
The request is a single UTF-8 line formatted as: <URL>\r\n
|
The request is a single UTF-8 line formatted as: <URL>\r\n
|
||||||
"""
|
"""
|
||||||
data = await self.reader.readuntil(b"\r\n")
|
data = self.rfile.readline(1026)
|
||||||
data = data[:-2] # strip the line ending
|
data = data.rstrip(b"\r\n")
|
||||||
if len(data) > 1024:
|
if len(data) > 1024:
|
||||||
raise ValueError("URL exceeds max length of 1024 bytes")
|
raise ValueError("URL exceeds max length of 1024 bytes")
|
||||||
|
|
||||||
|
@ -650,53 +637,48 @@ class GeminiRequestHandler:
|
||||||
self.meta = meta
|
self.meta = meta
|
||||||
self.response_buffer = f"{status}\t{meta}\r\n"
|
self.response_buffer = f"{status}\t{meta}\r\n"
|
||||||
|
|
||||||
async def write_body(self, data: bytes) -> None:
|
def write_body(self, data: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Write bytes to the gemini response body.
|
Write bytes to the gemini response body.
|
||||||
"""
|
"""
|
||||||
await self.flush_status()
|
self.flush_status()
|
||||||
self.response_size += len(data)
|
self.response_size += len(data)
|
||||||
self.writer.write(data)
|
self.wfile.write(data)
|
||||||
await self.writer.drain()
|
|
||||||
|
|
||||||
async def flush_status(self) -> None:
|
def flush_status(self) -> None:
|
||||||
"""
|
"""
|
||||||
Flush the status line from the internal buffer to the socket stream.
|
Flush the status line from the internal buffer to the socket stream.
|
||||||
"""
|
"""
|
||||||
if self.response_buffer and not self.response_size:
|
if self.response_buffer and not self.response_size:
|
||||||
data = self.response_buffer.encode()
|
data = self.response_buffer.encode()
|
||||||
self.response_size += len(data)
|
self.response_size += len(data)
|
||||||
self.writer.write(data)
|
self.wfile.write(data)
|
||||||
await self.writer.drain()
|
|
||||||
self.response_buffer = ""
|
self.response_buffer = ""
|
||||||
|
|
||||||
async def close_connection(self) -> None:
|
def finish(self) -> None:
|
||||||
"""
|
self.flush_status()
|
||||||
Flush any remaining bytes and close the stream.
|
try:
|
||||||
"""
|
|
||||||
await self.flush_status()
|
|
||||||
self.log_request()
|
self.log_request()
|
||||||
await self.writer.drain()
|
except AttributeError:
|
||||||
|
# Malformed request or dropped connection
|
||||||
|
pass
|
||||||
|
super().finish()
|
||||||
|
|
||||||
def log_request(self) -> None:
|
def log_request(self) -> None:
|
||||||
"""
|
"""
|
||||||
Log a gemini request using a format derived from the Common Log Format.
|
Log a gemini request using a format derived from the Common Log Format.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
self.server.log_message(
|
self.server.log_message(
|
||||||
f"{self.remote_addr} "
|
f"{self.client_address[0]} "
|
||||||
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
|
f"[{time.strftime(self.TIMESTAMP_FORMAT, self.received_timestamp)}] "
|
||||||
f'"{self.url}" '
|
f'"{self.url}" '
|
||||||
f"{self.status} "
|
f"{self.status} "
|
||||||
f'"{self.meta}" '
|
f'"{self.meta}" '
|
||||||
f"{self.response_size}"
|
f"{self.response_size}"
|
||||||
)
|
)
|
||||||
except AttributeError:
|
|
||||||
# Malformed request or dropped connection
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiServer:
|
class GeminiServer(socketserver.ThreadingTCPServer):
|
||||||
"""
|
"""
|
||||||
An asynchronous TCP server that uses the asyncio stream abstraction.
|
An asynchronous TCP server that uses the asyncio stream abstraction.
|
||||||
|
|
||||||
|
@ -715,47 +697,46 @@ class GeminiServer:
|
||||||
hostname: str = "localhost",
|
hostname: str = "localhost",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.hostname = hostname
|
|
||||||
self.app = app
|
self.app = app
|
||||||
|
self.hostname = hostname
|
||||||
self.ssl_context = ssl_context
|
self.ssl_context = ssl_context
|
||||||
|
super().__init__((host, port), self.request_handler_class, False)
|
||||||
|
|
||||||
async def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
The main asynchronous server loop.
|
Launch the main server loop.
|
||||||
"""
|
"""
|
||||||
self.log_message(ABOUT)
|
self.log_message(ABOUT)
|
||||||
server = await asyncio.start_server(
|
|
||||||
self.accept_connection, self.host, self.port, ssl=self.ssl_context
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_message(f"Server hostname is {self.hostname}")
|
self.log_message(f"Server hostname is {self.hostname}")
|
||||||
for sock in server.sockets:
|
try:
|
||||||
sock_ip, sock_port, *_ = sock.getsockname()
|
self.server_bind()
|
||||||
if sock.family == socket.AF_INET:
|
self.server_activate()
|
||||||
|
except Exception:
|
||||||
|
self.server_close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
sock_ip, sock_port, *_ = self.server_address
|
||||||
|
if self.address_family == socket.AF_INET:
|
||||||
self.log_message(f"Listening on {sock_ip}:{sock_port}")
|
self.log_message(f"Listening on {sock_ip}:{sock_port}")
|
||||||
else:
|
else:
|
||||||
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
||||||
|
|
||||||
async with server:
|
self.serve_forever()
|
||||||
await server.serve_forever()
|
|
||||||
|
|
||||||
async def accept_connection(
|
def get_request(self):
|
||||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Hook called by the socket server when a new connection is accepted.
|
Wrap the incoming request in an SSL connection.
|
||||||
"""
|
"""
|
||||||
request_handler = self.request_handler_class(self, self.app)
|
# noinspection PyTupleAssignmentBalance
|
||||||
try:
|
sock, client_addr = super(GeminiServer, self).get_request()
|
||||||
await request_handler.handle(reader, writer)
|
ssl_sock = self.ssl_context.wrap_socket(
|
||||||
finally:
|
sock, server_side=True, do_handshake_on_connect=False
|
||||||
writer.close()
|
)
|
||||||
|
return ssl_sock, client_addr
|
||||||
|
|
||||||
def log_message(self, message: str) -> None:
|
def log_message(self, message: str) -> None:
|
||||||
"""
|
"""
|
||||||
Log a diagnostic server message.
|
Log a diagnostic server message to stderr, may be overridden.
|
||||||
"""
|
"""
|
||||||
print(message, file=sys.stderr)
|
print(message, file=sys.stderr)
|
||||||
|
|
||||||
|
@ -852,7 +833,7 @@ def run_server() -> None:
|
||||||
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
|
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
|
||||||
)
|
)
|
||||||
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
|
server = GeminiServer(app, args.host, args.port, ssl_context, args.hostname)
|
||||||
asyncio.run(server.run())
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue