From 135dbda8784e8e37e28969a1409cc30df94260bf Mon Sep 17 00:00:00 2001 From: Michael Lazar Date: Tue, 29 Dec 2020 00:03:07 -0500 Subject: [PATCH] Type hints (#49) * Better type hint coverage * linting fixes * Keep the old ResponseType name for backwards compatibility * Remove unnecessary assert statement Co-authored-by: Michael Lazar --- CHANGELOG.md | 7 +++-- README.md | 4 +-- examples/cgi/cowsay.cgi | 2 +- examples/counter.py | 4 +-- jetforce/__main__.py | 2 +- jetforce/app/base.py | 64 ++++++++++++++++++++++++++++----------- jetforce/app/composite.py | 17 ++++++++--- jetforce/app/static.py | 14 ++++++--- jetforce/protocol.py | 23 +++++++------- jetforce/server.py | 9 +++--- jetforce/tls.py | 8 +++-- jetforce_client.py | 11 +++++-- setup.py | 2 +- 13 files changed, 112 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16276fb..d17e62a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,13 @@ - Added support for international domain names using IDN encoding. +#### New Features + +- Several fixes & improvements to python type hinting coverage. +- Added a ``py.typed`` file to indicate project support for type hints. + #### Bug Fixes -- Added py.typed file to indicate that the jetforce python library has support - for type hints. - Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to ``True``/``False`` instead of ``1``/``0``. diff --git a/README.md b/README.md index 98d0e82..8ef4e4c 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ additional modification by the server.
GATEWAY_INTERFACE
- CGI version (for compatability with RFC 3785).
+ CGI version (for compatibility with RFC 3785).
Example: "CGI/1.1"
@@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
AUTH_TYPE
- Authentication type (for compatability with RFC 3785).
+ Authentication type (for compatibility with RFC 3785).
Example: "CERTIFICATE"
diff --git a/examples/cgi/cowsay.cgi b/examples/cgi/cowsay.cgi index 07b6aaa..1a2cf61 100755 --- a/examples/cgi/cowsay.cgi +++ b/examples/cgi/cowsay.cgi @@ -1,5 +1,5 @@ #!/usr/local/bin/python3.7 -""" +r""" CGI script that requests user supplied text using the INPUT status, and pipes it into the `cowsay` program. diff --git a/examples/counter.py b/examples/counter.py index 16ca794..786daf0 100644 --- a/examples/counter.py +++ b/examples/counter.py @@ -59,8 +59,8 @@ def deferred_counter(): eventually run in the main event loop. """ - def delayed_callback(x): - return f"{x}\r\n" + def delayed_callback(var): + return f"{var}\r\n" for x in range(10): yield deferLater(reactor, 1, delayed_callback, x) diff --git a/jetforce/__main__.py b/jetforce/__main__.py index 1c1e775..6f03e14 100644 --- a/jetforce/__main__.py +++ b/jetforce/__main__.py @@ -106,7 +106,7 @@ group.add_argument( ) -def main(): +def main() -> None: args = parser.parse_args() rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None app = StaticDirectoryApplication( diff --git a/jetforce/app/base.py b/jetforce/app/base.py index f9651ec..b147bb4 100644 --- a/jetforce/app/base.py +++ b/jetforce/app/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import re import time @@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse from twisted.internet.defer import Deferred +EnvironDict = typing.Dict[str, object] ResponseType = typing.Union[str, bytes, Deferred] +ApplicationResponse = typing.Iterable[ResponseType] +WriteStatusCallable = typing.Callable[[int, str], None] +ApplicationCallable = typing.Callable[ + [EnvironDict, WriteStatusCallable], ApplicationResponse +] class Status: @@ -45,9 +53,19 @@ class Request: Object that encapsulates information about a single gemini request. """ - def __init__(self, environ: dict): + environ: EnvironDict + url: str + scheme: str + hostname: str + port: typing.Optional[int] + path: str + params: str + query: str + fragment: str + + def __init__(self, environ: EnvironDict): self.environ = environ - self.url = environ["GEMINI_URL"] + self.url = typing.cast(str, environ["GEMINI_URL"]) url_parts = urlparse(self.url) if not url_parts.hostname: @@ -84,7 +102,10 @@ class Response: status: int meta: str - body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None + body: typing.Union[None, ResponseType, ApplicationResponse] = None + + +RouteHandler = typing.Callable[..., Response] @dataclasses.dataclass @@ -101,7 +122,7 @@ class RoutePattern: strict_port: bool = True strict_trailing_slash: bool = False - def match(self, request: Request) -> typing.Optional[re.Match]: + def match(self, request: Request) -> typing.Optional[re.Match[str]]: """ Check if the given request URL matches this route pattern. """ @@ -112,12 +133,12 @@ class RoutePattern: server_port = request.environ["SERVER_PORT"] if self.strict_hostname and request.hostname != server_hostname: - return + return None if self.strict_port and request.port is not None: if request.port != server_port: - return + return None if self.scheme and self.scheme != request.scheme: - return + return None if self.strict_trailing_slash: request_path = request.path @@ -127,9 +148,6 @@ class RoutePattern: return re.fullmatch(self.path, request_path) -RouteHandler = typing.Callable[..., Response] - - class RateLimiter: """ A class that can be used to apply rate-limiting to endpoints. @@ -144,6 +162,11 @@ class RateLimiter: RE = re.compile("(?P[0-9]+)/(?P[0-9]+)?(?P[smhd])") + number: int + period: int + next_timestamp: float + rate_counter: typing.Dict[typing.Any, int] + def __init__(self, rate: str) -> None: match = self.RE.fullmatch(rate) if not match: @@ -166,7 +189,7 @@ class RateLimiter: self.next_timestamp = time.time() + self.period self.rate_counter = defaultdict(int) - def get_key(self, request: Request) -> typing.Optional[str]: + def get_key(self, request: Request) -> typing.Any: """ Rate limit based on the client's IP-address. """ @@ -190,6 +213,8 @@ class RateLimiter: msg = f"Rate limit exceeded, wait {time_left:.0f} seconds." return Response(Status.SLOW_DOWN, msg) + return None + def apply(self, wrapped_func: RouteHandler) -> RouteHandler: """ Decorator to apply rate limiting to an individual application route. @@ -203,7 +228,7 @@ class RateLimiter: return Response(Status.SUCCESS, "text/gemini", "hello world!") """ - def wrapper(request: Request, **kwargs) -> Response: + def wrapper(request: Request, **kwargs: typing.Any) -> Response: response = self.check(request) if response: return response @@ -224,13 +249,16 @@ class JetforceApplication: how to accomplish this. """ + rate_limiter: typing.Optional[RateLimiter] + routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] + def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None): self.rate_limiter = rate_limiter - self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = [] + self.routes = [] def __call__( - self, environ: dict, send_status: typing.Callable - ) -> typing.Iterator[ResponseType]: + self, environ: EnvironDict, send_status: WriteStatusCallable + ) -> ApplicationResponse: try: request = Request(environ) except Exception: @@ -245,7 +273,7 @@ class JetforceApplication: for route_pattern, callback in self.routes[::-1]: match = route_pattern.match(request) - if route_pattern.match(request): + if match: callback_kwargs = match.groupdict() break else: @@ -267,7 +295,7 @@ class JetforceApplication: hostname: typing.Optional[str] = None, strict_hostname: bool = True, strict_trailing_slash: bool = False, - ) -> typing.Callable: + ) -> typing.Callable[[RouteHandler], RouteHandler]: """ Decorator for binding a function to a route based on the URL path. @@ -287,7 +315,7 @@ class JetforceApplication: return wrap - def default_callback(self, request: Request, **_) -> Response: + def default_callback(self, request: Request, **_: typing.Any) -> Response: """ Set the error response based on the URL type. """ diff --git a/jetforce/app/composite.py b/jetforce/app/composite.py index 94940de..dccc7d5 100644 --- a/jetforce/app/composite.py +++ b/jetforce/app/composite.py @@ -1,6 +1,15 @@ import typing -from .base import Request, ResponseType, Status +from .base import ( + ApplicationCallable, + ApplicationResponse, + EnvironDict, + Request, + Status, + WriteStatusCallable, +) + +ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable] class CompositeApplication: @@ -11,7 +20,7 @@ class CompositeApplication: two or more applications behind a single jetforce server. """ - def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]): + def __init__(self, application_map: ApplicationMap): """ Initialize the application by providing a mapping of hostname -> app key pairs. A hostname of `None` is a special key that can be used as @@ -29,8 +38,8 @@ class CompositeApplication: self.application_map = application_map def __call__( - self, environ: dict, send_status: typing.Callable - ) -> typing.Iterator[ResponseType]: + self, environ: EnvironDict, send_status: WriteStatusCallable + ) -> ApplicationResponse: try: request = Request(environ) except Exception: diff --git a/jetforce/app/static.py b/jetforce/app/static.py index dbb6cbe..5ff3097 100644 --- a/jetforce/app/static.py +++ b/jetforce/app/static.py @@ -7,6 +7,7 @@ import typing import urllib.parse from .base import ( + EnvironDict, JetforceApplication, RateLimiter, Request, @@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication): # Chunk size for streaming files, taken from the twisted FileSender class CHUNK_SIZE = 2 ** 14 + mimetypes: mimetypes.MimeTypes + def __init__( self, root_directory: str = "/var/gemini", @@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication): if os.path.isfile(fn): self.mimetypes.read(fn) - self.mimetypes.add_type("text/gemini", ".gmi") - self.mimetypes.add_type("text/gemini", ".gemini") + # This is a valid method but the type stubs are incorrect + self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore + self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore def serve_static_file(self, request: Request) -> Response: """ @@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication): else: return Response(Status.NOT_FOUND, "Not Found") - def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response: + def run_cgi_script( + self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict + ) -> Response: """ Execute the given file as a CGI script and return the script's stdout stream to the client. @@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication): meta += f"; lang={self.default_lang}" return meta - def default_callback(self, request: Request, **_) -> Response: + def default_callback(self, request: Request, **_: typing.Any) -> Response: """ Since the StaticDirectoryApplication only serves gemini URLs, return a proxy request refused for suspicious URLs. diff --git a/jetforce/protocol.py b/jetforce/protocol.py index 1a65265..99b36a9 100644 --- a/jetforce/protocol.py +++ b/jetforce/protocol.py @@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed from twisted.internet.protocol import connectionDone from twisted.internet.task import deferLater from twisted.protocols.basic import LineOnlyReceiver +from twisted.python.failure import Failure from .__version__ import __version__ -from .app.base import JetforceApplication, Status +from .app.base import ApplicationCallable, EnvironDict, Status from .tls import inspect_certificate if typing.TYPE_CHECKING: @@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver): response_buffer: str response_size: int - def __init__(self, server: GeminiServer, app: JetforceApplication): + def __init__(self, server: GeminiServer, app: ApplicationCallable): self.server = server self.app = app self._currently_deferred: typing.Optional[Deferred] = None - def connectionMade(self): + def connectionMade(self) -> None: """ This is invoked by twisted after the connection is first established. """ @@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver): self.response_buffer = "" self.client_addr = self.transport.getPeer() - def connectionLost(self, reason=connectionDone): + def connectionLost(self, reason: Failure = connectionDone) -> None: """ This is invoked by twisted after the connection has been closed. """ @@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver): self._currently_deferred.errback(reason) self._currently_deferred = None - def lineReceived(self, line): + def lineReceived(self, line: bytes) -> Deferred: """ This method is invoked by LineOnlyReceiver for every incoming line. """ self.request = line return ensureDeferred(self._handle_request_noblock()) - def lineLengthExceeded(self, line): + def lineLengthExceeded(self, line: bytes) -> None: """ Called when the maximum line length has been reached. """ - return self.finish_connection() + self.finish_connection() - def finish_connection(self): + def finish_connection(self) -> None: """ Send the TLS "close_notify" alert and then immediately close the TCP connection without waiting for the client to respond with it's own @@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver): # part of the above TLS shutdown. self.transport.transport.loseConnection() - async def _handle_request_noblock(self): + async def _handle_request_noblock(self) -> None: """ Handle the gemini request and write the raw response to the socket. @@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver): self.log_request() self.finish_connection() - async def track_deferred(self, deferred: Deferred): + async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]: self._currently_deferred = deferred try: return await deferred finally: self._currently_deferred = None - def build_environ(self) -> typing.Dict[str, typing.Any]: + def build_environ(self) -> EnvironDict: """ Construct a dictionary that will be passed to the application handler. diff --git a/jetforce/server.py b/jetforce/server.py index 092ac99..41916f3 100644 --- a/jetforce/server.py +++ b/jetforce/server.py @@ -4,13 +4,14 @@ import socket import sys import typing -from twisted.internet import reactor +from twisted.internet import reactor as _reactor from twisted.internet.base import ReactorBase from twisted.internet.endpoints import SSL4ServerEndpoint from twisted.internet.protocol import Factory from twisted.internet.tcp import Port from .__version__ import __version__ +from .app.base import ApplicationCallable from .protocol import GeminiProtocol from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate @@ -49,8 +50,8 @@ class GeminiServer(Factory): def __init__( self, - app: typing.Callable, - reactor: ReactorBase = reactor, + app: ApplicationCallable, + reactor: ReactorBase = _reactor, host: str = "127.0.0.1", port: int = 1965, hostname: str = "localhost", @@ -95,7 +96,7 @@ class GeminiServer(Factory): else: self.log_message(f"Listening on [{sock_ip}]:{sock_port}") - def buildProtocol(self, addr) -> GeminiProtocol: + def buildProtocol(self, addr: typing.Any) -> GeminiProtocol: """ This method is invoked by twisted once for every incoming connection. diff --git a/jetforce/tls.py b/jetforce/tls.py index 745b320..b8315a8 100644 --- a/jetforce/tls.py +++ b/jetforce/tls.py @@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom COMMON_NAME = x509.NameOID.COMMON_NAME -def inspect_certificate(cert: x509) -> dict: +def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]: """ Extract useful fields from a x509 client certificate object. """ @@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]: subject_name = x509.Name([common_name]) not_valid_before = datetime.datetime.utcnow() not_valid_after = not_valid_before + datetime.timedelta(days=365) - certificate = x509.CertificateBuilder( + cert_builder = x509.CertificateBuilder( subject_name=subject_name, issuer_name=subject_name, public_key=private_key.public_key(), @@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]: not_valid_before=not_valid_before, not_valid_after=not_valid_after, ) - certificate = certificate.sign(private_key, hashes.SHA256(), backend) + certificate = cert_builder.sign(private_key, hashes.SHA256(), backend) with open(certfile, "wb") as fp: # noinspection PyTypeChecker cert_data = certificate.public_bytes(serialization.Encoding.PEM) @@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions): https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py """ + _acceptableProtocols: typing.List[bytes] + def verify_callback( self, conn: OpenSSL.SSL.Connection, diff --git a/jetforce_client.py b/jetforce_client.py index be0654f..8e8ecf7 100755 --- a/jetforce_client.py +++ b/jetforce_client.py @@ -6,6 +6,7 @@ import argparse import socket import ssl import sys +import typing import urllib.parse context = ssl.create_default_context() @@ -13,7 +14,12 @@ context.check_hostname = False context.verify_mode = ssl.CERT_NONE -def fetch(url, host=None, port=None, use_sni=False): +def fetch( + url: str, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, + use_sni: bool = False, +) -> None: parsed_url = urllib.parse.urlparse(url) if not parsed_url.scheme: parsed_url = urllib.parse.urlparse(f"gemini://{url}") @@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False): # ssock.unwrap() -def run_client(): +def run_client() -> None: # fmt: off parser = argparse.ArgumentParser(description="A simple gemini client") parser.add_argument("url") @@ -59,6 +65,7 @@ def run_client(): context.set_alpn_protocols([args.tls_alpn_protocol]) if args.tls_keylog: + # This is a "private" variable that the stdlib exposes for debugging context.keylog_filename = args.tls_keylog fetch(args.url, args.host, args.port, args.tls_enable_sni) diff --git a/setup.py b/setup.py index 899956b..b9efcc7 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import codecs import setuptools -def long_description(): +def long_description() -> str: with codecs.open("README.md", encoding="utf8") as f: return f.read()