Polishing up the twisted interface

This commit is contained in:
Michael Lazar 2020-05-13 00:32:51 -04:00
parent 335d79ad54
commit 124de25502
1 changed files with 47 additions and 36 deletions

View File

@ -57,6 +57,7 @@ from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from OpenSSL import SSL from OpenSSL import SSL
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.base import ReactorBase from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
@ -626,6 +627,8 @@ class GeminiProtocol(LineOnlyReceiver):
TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z"
client_addr: typing.Union[IPv4Address, IPv6Address]
client_cert: typing.Optional[x509.Certificate]
connected_timestamp: time.struct_time connected_timestamp: time.struct_time
request: bytes request: bytes
url: str url: str
@ -645,6 +648,12 @@ class GeminiProtocol(LineOnlyReceiver):
self.connected_timestamp = time.localtime() self.connected_timestamp = time.localtime()
self.response_size = 0 self.response_size = 0
self.response_buffer = "" self.response_buffer = ""
self.client_addr = self.transport.getPeer()
self.client_cert = None
peer_cert = self.transport.getPeerCertificate()
if peer_cert:
self.client_cert = peer_cert.to_cryptography()
def lineReceived(self, line): def lineReceived(self, line):
""" """
@ -656,25 +665,19 @@ class GeminiProtocol(LineOnlyReceiver):
""" """
self.request = line self.request = line
try: try:
try: self.handle_request()
self.handle()
finally:
self.flush_status()
try:
self.log_request()
except Exception:
# Malformed request or dropped connection
pass
finally: finally:
self.log_request()
self.transport.loseConnection() self.transport.loseConnection()
def handle(self): def handle_request(self):
try: try:
self.parse_header() self.parse_header()
except Exception: except Exception:
# Malformed request, throw it away and exit immediately # Malformed request, throw it away and exit immediately
self.write_status(Status.BAD_REQUEST, "Malformed request") self.write_status(Status.BAD_REQUEST, "Malformed request")
return self.flush_status()
raise
try: try:
environ = self.build_environ() environ = self.build_environ()
@ -683,6 +686,9 @@ class GeminiProtocol(LineOnlyReceiver):
self.write_body(data) self.write_body(data)
except Exception: except Exception:
self.write_status(Status.CGI_ERROR, "An unexpected error occurred") self.write_status(Status.CGI_ERROR, "An unexpected error occurred")
raise
finally:
self.flush_status()
def build_environ(self) -> typing.Dict[str, typing.Any]: def build_environ(self) -> typing.Dict[str, typing.Any]:
""" """
@ -691,25 +697,22 @@ class GeminiProtocol(LineOnlyReceiver):
Variable names conform to the CGI spec defined in RFC 3875. Variable names conform to the CGI spec defined in RFC 3875.
""" """
url_parts = urllib.parse.urlparse(self.url) url_parts = urllib.parse.urlparse(self.url)
client_addr = self.transport.getPeer()
environ = { environ = {
"GEMINI_URL": self.url, "GEMINI_URL": self.url,
"HOSTNAME": self.server.hostname, "HOSTNAME": self.server.hostname,
"PATH_INFO": url_parts.path, "PATH_INFO": url_parts.path,
"QUERY_STRING": url_parts.query, "QUERY_STRING": url_parts.query,
"REMOTE_ADDR": client_addr.host, "REMOTE_ADDR": self.client_addr.host,
"REMOTE_HOST": client_addr.host, "REMOTE_HOST": self.client_addr.host,
"SERVER_NAME": self.server.hostname, "SERVER_NAME": self.server.hostname,
"SERVER_PORT": str(client_addr.port), "SERVER_PORT": str(self.client_addr.port),
"SERVER_PROTOCOL": "GEMINI", "SERVER_PROTOCOL": "GEMINI",
"SERVER_SOFTWARE": f"jetforce/{__version__}", "SERVER_SOFTWARE": f"jetforce/{__version__}",
} }
if self.client_cert:
openssl_cert = self.transport.getPeerCertificate()
if openssl_cert:
# Extract useful information from the client certificate. These # Extract useful information from the client certificate. These
# mostly follow the naming convention from GLV-1.12556 # mostly follow the naming convention from GLV-1.12556
cert = openssl_cert.to_cryptography() cert = self.client_cert
name_attrs = cert.subject.get_attributes_for_oid(CN) name_attrs = cert.subject.get_attributes_for_oid(CN)
common_name = name_attrs[0].value if name_attrs else "" common_name = name_attrs[0].value if name_attrs else ""
fingerprint_bytes = cert.fingerprint(hashes.SHA256()) fingerprint_bytes = cert.fingerprint(hashes.SHA256())
@ -726,7 +729,6 @@ class GeminiProtocol(LineOnlyReceiver):
"TLS_CLIENT_SERIAL_NUMBER": cert.serial_number, "TLS_CLIENT_SERIAL_NUMBER": cert.serial_number,
} }
) )
return environ return environ
def parse_header(self) -> None: def parse_header(self) -> None:
@ -781,15 +783,20 @@ class GeminiProtocol(LineOnlyReceiver):
""" """
Log a gemini request using a format derived from the Common Log Format. Log a gemini request using a format derived from the Common Log Format.
""" """
message = '{} [{}] "{}" {} {} {}'.format( try:
self.transport.getPeer().host, message = '{} [{}] "{}" {} {} {}'.format(
time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp), self.client_addr.host,
self.url, time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp),
self.status, self.url,
self.meta, self.status,
self.response_size, self.meta,
) self.response_size,
self.server.log_message(message) )
except AttributeError:
# The connection ended before we got far enough to log anything
pass
else:
self.server.log_message(message)
class GeminiServer(Factory): class GeminiServer(Factory):
@ -873,13 +880,17 @@ class GeminiServer(Factory):
cafile=self.cafile, cafile=self.cafile,
capath=self.capath, capath=self.capath,
) )
endpoint = self.endpoint_class(
reactor=self.reactor, interfaces = [self.host] if self.host else ["0.0.0.0", "::"]
port=self.port, for interface in interfaces:
sslContextFactory=tls_context_factory, endpoint = self.endpoint_class(
interface=self.host, reactor=self.reactor,
) port=self.port,
endpoint.listen(self).addCallback(self.on_bind_interface) sslContextFactory=tls_context_factory,
interface=interface,
)
endpoint.listen(self).addCallback(self.on_bind_interface)
self.reactor.run() self.reactor.run()