diff --git a/examples/echo_server.py b/examples/echo_server.py index 57a6a32..1056887 100644 --- a/examples/echo_server.py +++ b/examples/echo_server.py @@ -14,11 +14,13 @@ def echo(environ, send_status): if __name__ == "__main__": args = jetforce.command_line_parser().parse_args() + ssl_context = jetforce.make_ssl_context( + args.hostname, args.certfile, args.keyfile, args.cafile, args.capath + ) server = jetforce.GeminiServer( host=args.host, port=args.port, - certfile=args.certfile, - keyfile=args.keyfile, + ssl_context=ssl_context, hostname=args.hostname, app=echo, ) diff --git a/examples/guestbook.py b/examples/guestbook.py index f2539d6..067926c 100644 --- a/examples/guestbook.py +++ b/examples/guestbook.py @@ -24,7 +24,6 @@ def index(request): for line in fp: line = line.strip() if line.startswith("=>"): - # Protect guests from writing messages that contain links data.append(line[2:]) else: data.append(line) @@ -40,6 +39,7 @@ def submit(request): created = datetime.utcnow() with guestbook.open("a") as fp: fp.write(f"\n[{created:%Y-%m-%d %I:%M %p}]\n{message}\n") + return Response(Status.REDIRECT_TEMPORARY, "") else: return Response(Status.INPUT, "Enter your message (max 256 characters)") @@ -47,11 +47,13 @@ def submit(request): if __name__ == "__main__": args = jetforce.command_line_parser().parse_args() + ssl_context = jetforce.make_ssl_context( + args.hostname, args.certfile, args.keyfile, args.cafile, args.capath + ) server = jetforce.GeminiServer( host=args.host, port=args.port, - certfile=args.certfile, - keyfile=args.keyfile, + ssl_context=ssl_context, hostname=args.hostname, app=app, ) diff --git a/examples/http_proxy.py b/examples/http_proxy.py index 836919d..1708c2c 100644 --- a/examples/http_proxy.py +++ b/examples/http_proxy.py @@ -27,11 +27,13 @@ def proxy_request(request): if __name__ == "__main__": args = jetforce.command_line_parser().parse_args() + ssl_context = jetforce.make_ssl_context( + args.hostname, args.certfile, args.keyfile, args.cafile, args.capath + ) server = jetforce.GeminiServer( host=args.host, port=args.port, - certfile=args.certfile, - keyfile=args.keyfile, + ssl_context=ssl_context, hostname=args.hostname, app=app, ) diff --git a/jetforce.py b/jetforce.py index cd354b0..5f216d9 100755 --- a/jetforce.py +++ b/jetforce.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3.7 +from __future__ import annotations + import argparse import asyncio import codecs @@ -83,7 +85,6 @@ class Request: def __init__(self, environ: dict): self.environ = environ - self.url = environ["GEMINI_URL"] url_parts = urllib.parse.urlparse(self.url) @@ -92,7 +93,7 @@ class Request: self.port = url_parts.port self.path = url_parts.path self.params = url_parts.params - self.query = url_parts.query + self.query = urllib.parse.unquote(url_parts.query) self.fragment = url_parts.fragment @@ -258,11 +259,17 @@ class StaticDirectoryApplication(JetforceApplication): return Response(Status.NOT_FOUND, "Not Found") def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response: + """ + Execute the given file as a CGI script and return the script's stdout + stream to the client. + """ script_name = str(filesystem_path) cgi_env = environ.copy() cgi_env["GATEWAY_INTERFACE"] = "GCI/1.1" cgi_env["SCRIPT_NAME"] = script_name + # Decode the stream as unicode so we can parse the status line + # Use surrogateescape to preserve any non-UTF8 byte sequences. out = subprocess.Popen( [script_name], stdout=subprocess.PIPE, @@ -334,7 +341,7 @@ class GeminiRequestHandler: removed or slimmed-down. """ - def __init__(self, server: "GeminiServer", app: typing.Callable) -> None: + def __init__(self, server: GeminiServer, app: typing.Callable) -> None: self.server = server self.app = app self.reader: typing.Optional[asyncio.StreamReader] = None @@ -488,8 +495,7 @@ class GeminiServer: app: typing.Callable, host: str = "127.0.0.1", port: int = 1965, - certfile: typing.Optional[str] = None, - keyfile: typing.Optional[str] = None, + ssl_context: ssl.SSLContext = None, hostname: str = "localhost", ) -> None: @@ -497,14 +503,7 @@ class GeminiServer: self.port = port self.hostname = hostname self.app = app - - if not certfile: - certfile, keyfile = self.generate_tls_certificate(hostname) - - self.ssl_context = ssl.SSLContext() - self.ssl_context.verify_mode = ssl.CERT_OPTIONAL - self.ssl_context.check_hostname = False - self.ssl_context.load_cert_chain(certfile, keyfile) + self.ssl_context = ssl_context async def run(self) -> None: """ @@ -540,31 +539,71 @@ class GeminiServer: """ print(message, file=sys.stderr) - @staticmethod - def generate_tls_certificate(hostname: str) -> typing.Tuple[str, str]: - """ - Utility function to generate a self-signed SSL certificate key pair if - one isn't provided. Results may vary depending on your version of OpenSSL. - """ - certfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.crt" - keyfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.key" - if not certfile.exists() or not keyfile.exists(): - print(f"Writing ad hoc TLS certificate to {certfile}") - subprocess.run( - [ - f"openssl req -newkey rsa:2048 -nodes -keyout {keyfile}" - f' -nodes -x509 -out {certfile} -subj "/CN={hostname}"' - ], - shell=True, - check=True, - ) - return str(certfile), str(keyfile) + +def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]: + """ + Utility function to generate a self-signed SSL certificate key pair if + one isn't provided. Results may vary depending on your version of OpenSSL. + """ + certfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.crt" + keyfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.key" + if not certfile.exists() or not keyfile.exists(): + print(f"Writing ad hoc TLS certificate to {certfile}") + subprocess.run( + [ + f"openssl req -newkey rsa:2048 -nodes -keyout {keyfile}" + f' -nodes -x509 -out {certfile} -subj "/CN={hostname}"' + ], + shell=True, + check=True, + ) + return str(certfile), str(keyfile) + + +def make_ssl_context( + hostname: str = "localhost", + certfile: typing.Optional[str] = None, + keyfile: typing.Optional[str] = None, + cafile: typing.Optional[str] = None, + capath: typing.Optional[str] = None, +) -> ssl.SSLContext: + """ + Generate a sane default SSL context for a Gemini server. + + For more information on what these variables mean and what values they can + contain, see the python standard library documentation: + + https://docs.python.org/3/library/ssl.html#ssl-contexts + + verify_mode: ssl.CERT_OPTIONAL + A client certificate request is sent to the client. The client may + either ignore the request or send a certificate in order perform TLS + client cert authentication. If the client chooses to send a certificate, + it is verified. Any verification error immediately aborts the TLS + handshake. + """ + if certfile is None: + certfile, keyfile = generate_ad_hoc_certificate(hostname) + + context = ssl.SSLContext() + context.verify_mode = ssl.CERT_OPTIONAL + context.load_cert_chain(certfile, keyfile) + + if not cafile and not capath: + # Load from the system's default client CA directory + context.load_default_certs(purpose=ssl.Purpose.CLIENT_AUTH) + else: + # Use a custom CA for validating client certificates + context.load_verify_locations(cafile, capath) + + return context def command_line_parser() -> argparse.ArgumentParser: """ Construct the default argument parser when launching the server from - the command line. + the command line. These are meant to be application-agnostic arguments + that could apply to any subclass of the JetforceApplication. """ parser = argparse.ArgumentParser( prog="jetforce", @@ -574,6 +613,7 @@ def command_line_parser() -> argparse.ArgumentParser: ) parser.add_argument("--host", help="Server address to bind to", default="127.0.0.1") parser.add_argument("--port", help="Server port to bind to", type=int, default=1965) + parser.add_argument("--hostname", help="Server hostname", default="localhost") parser.add_argument( "--tls-certfile", dest="certfile", @@ -586,7 +626,18 @@ def command_line_parser() -> argparse.ArgumentParser: help="Server TLS private key file", metavar="FILE", ) - parser.add_argument("--hostname", help="Server hostname", default="localhost") + parser.add_argument( + "--tls-cafile", + dest="cafile", + help="A CA file to use for validating clients", + metavar="FILE", + ) + parser.add_argument( + "--tls-capath", + dest="capath", + help="A directory containing CA files for validating clients", + metavar="DIR", + ) return parser @@ -617,11 +668,13 @@ def run_server() -> None: args = parser.parse_args() app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir) + ssl_context = make_ssl_context( + args.hostname, args.certfile, args.keyfile, args.cafile, args.capath + ) server = GeminiServer( host=args.host, port=args.port, - certfile=args.certfile, - keyfile=args.keyfile, + ssl_context=ssl_context, hostname=args.hostname, app=app, )