Add commandline arguments for supplying a client CA

This commit is contained in:
Michael Lazar 2019-08-28 22:33:58 -04:00
parent 21737c425a
commit 62e1181e4b
4 changed files with 102 additions and 43 deletions

View File

@ -14,11 +14,13 @@ def echo(environ, send_status):
if __name__ == "__main__": if __name__ == "__main__":
args = jetforce.command_line_parser().parse_args() args = jetforce.command_line_parser().parse_args()
ssl_context = jetforce.make_ssl_context(
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = jetforce.GeminiServer( server = jetforce.GeminiServer(
host=args.host, host=args.host,
port=args.port, port=args.port,
certfile=args.certfile, ssl_context=ssl_context,
keyfile=args.keyfile,
hostname=args.hostname, hostname=args.hostname,
app=echo, app=echo,
) )

View File

@ -24,7 +24,6 @@ def index(request):
for line in fp: for line in fp:
line = line.strip() line = line.strip()
if line.startswith("=>"): if line.startswith("=>"):
# Protect guests from writing messages that contain links
data.append(line[2:]) data.append(line[2:])
else: else:
data.append(line) data.append(line)
@ -40,6 +39,7 @@ def submit(request):
created = datetime.utcnow() created = datetime.utcnow()
with guestbook.open("a") as fp: with guestbook.open("a") as fp:
fp.write(f"\n[{created:%Y-%m-%d %I:%M %p}]\n{message}\n") fp.write(f"\n[{created:%Y-%m-%d %I:%M %p}]\n{message}\n")
return Response(Status.REDIRECT_TEMPORARY, "") return Response(Status.REDIRECT_TEMPORARY, "")
else: else:
return Response(Status.INPUT, "Enter your message (max 256 characters)") return Response(Status.INPUT, "Enter your message (max 256 characters)")
@ -47,11 +47,13 @@ def submit(request):
if __name__ == "__main__": if __name__ == "__main__":
args = jetforce.command_line_parser().parse_args() args = jetforce.command_line_parser().parse_args()
ssl_context = jetforce.make_ssl_context(
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = jetforce.GeminiServer( server = jetforce.GeminiServer(
host=args.host, host=args.host,
port=args.port, port=args.port,
certfile=args.certfile, ssl_context=ssl_context,
keyfile=args.keyfile,
hostname=args.hostname, hostname=args.hostname,
app=app, app=app,
) )

View File

@ -27,11 +27,13 @@ def proxy_request(request):
if __name__ == "__main__": if __name__ == "__main__":
args = jetforce.command_line_parser().parse_args() args = jetforce.command_line_parser().parse_args()
ssl_context = jetforce.make_ssl_context(
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = jetforce.GeminiServer( server = jetforce.GeminiServer(
host=args.host, host=args.host,
port=args.port, port=args.port,
certfile=args.certfile, ssl_context=ssl_context,
keyfile=args.keyfile,
hostname=args.hostname, hostname=args.hostname,
app=app, app=app,
) )

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3.7 #!/usr/bin/env python3.7
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import codecs import codecs
@ -83,7 +85,6 @@ class Request:
def __init__(self, environ: dict): def __init__(self, environ: dict):
self.environ = environ self.environ = environ
self.url = environ["GEMINI_URL"] self.url = environ["GEMINI_URL"]
url_parts = urllib.parse.urlparse(self.url) url_parts = urllib.parse.urlparse(self.url)
@ -92,7 +93,7 @@ class Request:
self.port = url_parts.port self.port = url_parts.port
self.path = url_parts.path self.path = url_parts.path
self.params = url_parts.params self.params = url_parts.params
self.query = url_parts.query self.query = urllib.parse.unquote(url_parts.query)
self.fragment = url_parts.fragment self.fragment = url_parts.fragment
@ -258,11 +259,17 @@ class StaticDirectoryApplication(JetforceApplication):
return Response(Status.NOT_FOUND, "Not Found") return Response(Status.NOT_FOUND, "Not Found")
def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response: def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response:
"""
Execute the given file as a CGI script and return the script's stdout
stream to the client.
"""
script_name = str(filesystem_path) script_name = str(filesystem_path)
cgi_env = environ.copy() cgi_env = environ.copy()
cgi_env["GATEWAY_INTERFACE"] = "GCI/1.1" cgi_env["GATEWAY_INTERFACE"] = "GCI/1.1"
cgi_env["SCRIPT_NAME"] = script_name cgi_env["SCRIPT_NAME"] = script_name
# Decode the stream as unicode so we can parse the status line
# Use surrogateescape to preserve any non-UTF8 byte sequences.
out = subprocess.Popen( out = subprocess.Popen(
[script_name], [script_name],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@ -334,7 +341,7 @@ class GeminiRequestHandler:
removed or slimmed-down. removed or slimmed-down.
""" """
def __init__(self, server: "GeminiServer", app: typing.Callable) -> None: def __init__(self, server: GeminiServer, app: typing.Callable) -> None:
self.server = server self.server = server
self.app = app self.app = app
self.reader: typing.Optional[asyncio.StreamReader] = None self.reader: typing.Optional[asyncio.StreamReader] = None
@ -488,8 +495,7 @@ class GeminiServer:
app: typing.Callable, app: typing.Callable,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 1965, port: int = 1965,
certfile: typing.Optional[str] = None, ssl_context: ssl.SSLContext = None,
keyfile: typing.Optional[str] = None,
hostname: str = "localhost", hostname: str = "localhost",
) -> None: ) -> None:
@ -497,14 +503,7 @@ class GeminiServer:
self.port = port self.port = port
self.hostname = hostname self.hostname = hostname
self.app = app self.app = app
self.ssl_context = ssl_context
if not certfile:
certfile, keyfile = self.generate_tls_certificate(hostname)
self.ssl_context = ssl.SSLContext()
self.ssl_context.verify_mode = ssl.CERT_OPTIONAL
self.ssl_context.check_hostname = False
self.ssl_context.load_cert_chain(certfile, keyfile)
async def run(self) -> None: async def run(self) -> None:
""" """
@ -540,31 +539,71 @@ class GeminiServer:
""" """
print(message, file=sys.stderr) print(message, file=sys.stderr)
@staticmethod
def generate_tls_certificate(hostname: str) -> typing.Tuple[str, str]: def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
""" """
Utility function to generate a self-signed SSL certificate key pair if Utility function to generate a self-signed SSL certificate key pair if
one isn't provided. Results may vary depending on your version of OpenSSL. one isn't provided. Results may vary depending on your version of OpenSSL.
""" """
certfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.crt" certfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.crt"
keyfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.key" keyfile = pathlib.Path(tempfile.gettempdir()) / f"{hostname}.key"
if not certfile.exists() or not keyfile.exists(): if not certfile.exists() or not keyfile.exists():
print(f"Writing ad hoc TLS certificate to {certfile}") print(f"Writing ad hoc TLS certificate to {certfile}")
subprocess.run( subprocess.run(
[ [
f"openssl req -newkey rsa:2048 -nodes -keyout {keyfile}" f"openssl req -newkey rsa:2048 -nodes -keyout {keyfile}"
f' -nodes -x509 -out {certfile} -subj "/CN={hostname}"' f' -nodes -x509 -out {certfile} -subj "/CN={hostname}"'
], ],
shell=True, shell=True,
check=True, check=True,
) )
return str(certfile), str(keyfile) return str(certfile), str(keyfile)
def make_ssl_context(
hostname: str = "localhost",
certfile: typing.Optional[str] = None,
keyfile: typing.Optional[str] = None,
cafile: typing.Optional[str] = None,
capath: typing.Optional[str] = None,
) -> ssl.SSLContext:
"""
Generate a sane default SSL context for a Gemini server.
For more information on what these variables mean and what values they can
contain, see the python standard library documentation:
https://docs.python.org/3/library/ssl.html#ssl-contexts
verify_mode: ssl.CERT_OPTIONAL
A client certificate request is sent to the client. The client may
either ignore the request or send a certificate in order perform TLS
client cert authentication. If the client chooses to send a certificate,
it is verified. Any verification error immediately aborts the TLS
handshake.
"""
if certfile is None:
certfile, keyfile = generate_ad_hoc_certificate(hostname)
context = ssl.SSLContext()
context.verify_mode = ssl.CERT_OPTIONAL
context.load_cert_chain(certfile, keyfile)
if not cafile and not capath:
# Load from the system's default client CA directory
context.load_default_certs(purpose=ssl.Purpose.CLIENT_AUTH)
else:
# Use a custom CA for validating client certificates
context.load_verify_locations(cafile, capath)
return context
def command_line_parser() -> argparse.ArgumentParser: def command_line_parser() -> argparse.ArgumentParser:
""" """
Construct the default argument parser when launching the server from Construct the default argument parser when launching the server from
the command line. the command line. These are meant to be application-agnostic arguments
that could apply to any subclass of the JetforceApplication.
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="jetforce", prog="jetforce",
@ -574,6 +613,7 @@ def command_line_parser() -> argparse.ArgumentParser:
) )
parser.add_argument("--host", help="Server address to bind to", default="127.0.0.1") parser.add_argument("--host", help="Server address to bind to", default="127.0.0.1")
parser.add_argument("--port", help="Server port to bind to", type=int, default=1965) parser.add_argument("--port", help="Server port to bind to", type=int, default=1965)
parser.add_argument("--hostname", help="Server hostname", default="localhost")
parser.add_argument( parser.add_argument(
"--tls-certfile", "--tls-certfile",
dest="certfile", dest="certfile",
@ -586,7 +626,18 @@ def command_line_parser() -> argparse.ArgumentParser:
help="Server TLS private key file", help="Server TLS private key file",
metavar="FILE", metavar="FILE",
) )
parser.add_argument("--hostname", help="Server hostname", default="localhost") parser.add_argument(
"--tls-cafile",
dest="cafile",
help="A CA file to use for validating clients",
metavar="FILE",
)
parser.add_argument(
"--tls-capath",
dest="capath",
help="A directory containing CA files for validating clients",
metavar="DIR",
)
return parser return parser
@ -617,11 +668,13 @@ def run_server() -> None:
args = parser.parse_args() args = parser.parse_args()
app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir) app = StaticDirectoryApplication(args.dir, args.index_file, args.cgi_dir)
ssl_context = make_ssl_context(
args.hostname, args.certfile, args.keyfile, args.cafile, args.capath
)
server = GeminiServer( server = GeminiServer(
host=args.host, host=args.host,
port=args.port, port=args.port,
certfile=args.certfile, ssl_context=ssl_context,
keyfile=args.keyfile,
hostname=args.hostname, hostname=args.hostname,
app=app, app=app,
) )