diff --git a/jetforce.py b/jetforce.py index 12a37f2..569301f 100755 --- a/jetforce.py +++ b/jetforce.py @@ -57,6 +57,7 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from OpenSSL import SSL from twisted.internet import reactor +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.base import ReactorBase from twisted.internet.endpoints import SSL4ServerEndpoint from twisted.internet.protocol import Factory @@ -626,6 +627,8 @@ class GeminiProtocol(LineOnlyReceiver): TIMESTAMP_FORMAT = "%d/%b/%Y:%H:%M:%S %z" + client_addr: typing.Union[IPv4Address, IPv6Address] + client_cert: typing.Optional[x509.Certificate] connected_timestamp: time.struct_time request: bytes url: str @@ -645,6 +648,12 @@ class GeminiProtocol(LineOnlyReceiver): self.connected_timestamp = time.localtime() self.response_size = 0 self.response_buffer = "" + self.client_addr = self.transport.getPeer() + self.client_cert = None + + peer_cert = self.transport.getPeerCertificate() + if peer_cert: + self.client_cert = peer_cert.to_cryptography() def lineReceived(self, line): """ @@ -656,25 +665,19 @@ class GeminiProtocol(LineOnlyReceiver): """ self.request = line try: - try: - self.handle() - finally: - self.flush_status() - try: - self.log_request() - except Exception: - # Malformed request or dropped connection - pass + self.handle_request() finally: + self.log_request() self.transport.loseConnection() - def handle(self): + def handle_request(self): try: self.parse_header() except Exception: # Malformed request, throw it away and exit immediately self.write_status(Status.BAD_REQUEST, "Malformed request") - return + self.flush_status() + raise try: environ = self.build_environ() @@ -683,6 +686,9 @@ class GeminiProtocol(LineOnlyReceiver): self.write_body(data) except Exception: self.write_status(Status.CGI_ERROR, "An unexpected error occurred") + raise + finally: + self.flush_status() def build_environ(self) -> typing.Dict[str, typing.Any]: """ @@ -691,25 +697,22 @@ class GeminiProtocol(LineOnlyReceiver): Variable names conform to the CGI spec defined in RFC 3875. """ url_parts = urllib.parse.urlparse(self.url) - client_addr = self.transport.getPeer() environ = { "GEMINI_URL": self.url, "HOSTNAME": self.server.hostname, "PATH_INFO": url_parts.path, "QUERY_STRING": url_parts.query, - "REMOTE_ADDR": client_addr.host, - "REMOTE_HOST": client_addr.host, + "REMOTE_ADDR": self.client_addr.host, + "REMOTE_HOST": self.client_addr.host, "SERVER_NAME": self.server.hostname, - "SERVER_PORT": str(client_addr.port), + "SERVER_PORT": str(self.client_addr.port), "SERVER_PROTOCOL": "GEMINI", "SERVER_SOFTWARE": f"jetforce/{__version__}", } - - openssl_cert = self.transport.getPeerCertificate() - if openssl_cert: + if self.client_cert: # Extract useful information from the client certificate. These # mostly follow the naming convention from GLV-1.12556 - cert = openssl_cert.to_cryptography() + cert = self.client_cert name_attrs = cert.subject.get_attributes_for_oid(CN) common_name = name_attrs[0].value if name_attrs else "" fingerprint_bytes = cert.fingerprint(hashes.SHA256()) @@ -726,7 +729,6 @@ class GeminiProtocol(LineOnlyReceiver): "TLS_CLIENT_SERIAL_NUMBER": cert.serial_number, } ) - return environ def parse_header(self) -> None: @@ -781,15 +783,20 @@ class GeminiProtocol(LineOnlyReceiver): """ Log a gemini request using a format derived from the Common Log Format. """ - message = '{} [{}] "{}" {} {} {}'.format( - self.transport.getPeer().host, - time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp), - self.url, - self.status, - self.meta, - self.response_size, - ) - self.server.log_message(message) + try: + message = '{} [{}] "{}" {} {} {}'.format( + self.client_addr.host, + time.strftime(self.TIMESTAMP_FORMAT, self.connected_timestamp), + self.url, + self.status, + self.meta, + self.response_size, + ) + except AttributeError: + # The connection ended before we got far enough to log anything + pass + else: + self.server.log_message(message) class GeminiServer(Factory): @@ -873,13 +880,17 @@ class GeminiServer(Factory): cafile=self.cafile, capath=self.capath, ) - endpoint = self.endpoint_class( - reactor=self.reactor, - port=self.port, - sslContextFactory=tls_context_factory, - interface=self.host, - ) - endpoint.listen(self).addCallback(self.on_bind_interface) + + interfaces = [self.host] if self.host else ["0.0.0.0", "::"] + for interface in interfaces: + endpoint = self.endpoint_class( + reactor=self.reactor, + port=self.port, + sslContextFactory=tls_context_factory, + interface=interface, + ) + endpoint.listen(self).addCallback(self.on_bind_interface) + self.reactor.run()