Type hints (#49)

* Better type hint coverage

* linting fixes

* Keep the old ResponseType name for backwards compatibility

* Remove unnecessary assert statement

Co-authored-by: Michael Lazar <mlazar@doctorondemand.com>
This commit is contained in:
Michael Lazar 2020-12-29 00:03:07 -05:00 committed by GitHub
parent 637025c8c3
commit 135dbda878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 112 additions and 55 deletions

View File

@ -6,10 +6,13 @@
- Added support for international domain names using IDN encoding. - Added support for international domain names using IDN encoding.
#### New Features
- Several fixes & improvements to python type hinting coverage.
- Added a ``py.typed`` file to indicate project support for type hints.
#### Bug Fixes #### Bug Fixes
- Added py.typed file to indicate that the jetforce python library has support
for type hints.
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to - Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
``True``/``False`` instead of ``1``/``0``. ``True``/``False`` instead of ``1``/``0``.

View File

@ -190,7 +190,7 @@ additional modification by the server.
<dl> <dl>
<dt>GATEWAY_INTERFACE</dt> <dt>GATEWAY_INTERFACE</dt>
<dd> <dd>
CGI version (for compatability with RFC 3785).<br> CGI version (for compatibility with RFC 3785).<br>
<em>Example: "CGI/1.1"</em> <em>Example: "CGI/1.1"</em>
</dd> </dd>
@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
<dt>AUTH_TYPE</dt> <dt>AUTH_TYPE</dt>
<dd> <dd>
Authentication type (for compatability with RFC 3785).<br> Authentication type (for compatibility with RFC 3785).<br>
<em>Example: "CERTIFICATE"</em> <em>Example: "CERTIFICATE"</em>
</dd> </dd>

View File

@ -1,5 +1,5 @@
#!/usr/local/bin/python3.7 #!/usr/local/bin/python3.7
""" r"""
CGI script that requests user supplied text using the INPUT status, and CGI script that requests user supplied text using the INPUT status, and
pipes it into the `cowsay` program. pipes it into the `cowsay` program.

View File

@ -59,8 +59,8 @@ def deferred_counter():
eventually run in the main event loop. eventually run in the main event loop.
""" """
def delayed_callback(x): def delayed_callback(var):
return f"{x}\r\n" return f"{var}\r\n"
for x in range(10): for x in range(10):
yield deferLater(reactor, 1, delayed_callback, x) yield deferLater(reactor, 1, delayed_callback, x)

View File

@ -106,7 +106,7 @@ group.add_argument(
) )
def main(): def main() -> None:
args = parser.parse_args() args = parser.parse_args()
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
app = StaticDirectoryApplication( app = StaticDirectoryApplication(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses import dataclasses
import re import re
import time import time
@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
EnvironDict = typing.Dict[str, object]
ResponseType = typing.Union[str, bytes, Deferred] ResponseType = typing.Union[str, bytes, Deferred]
ApplicationResponse = typing.Iterable[ResponseType]
WriteStatusCallable = typing.Callable[[int, str], None]
ApplicationCallable = typing.Callable[
[EnvironDict, WriteStatusCallable], ApplicationResponse
]
class Status: class Status:
@ -45,9 +53,19 @@ class Request:
Object that encapsulates information about a single gemini request. Object that encapsulates information about a single gemini request.
""" """
def __init__(self, environ: dict): environ: EnvironDict
url: str
scheme: str
hostname: str
port: typing.Optional[int]
path: str
params: str
query: str
fragment: str
def __init__(self, environ: EnvironDict):
self.environ = environ self.environ = environ
self.url = environ["GEMINI_URL"] self.url = typing.cast(str, environ["GEMINI_URL"])
url_parts = urlparse(self.url) url_parts = urlparse(self.url)
if not url_parts.hostname: if not url_parts.hostname:
@ -84,7 +102,10 @@ class Response:
status: int status: int
meta: str meta: str
body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None body: typing.Union[None, ResponseType, ApplicationResponse] = None
RouteHandler = typing.Callable[..., Response]
@dataclasses.dataclass @dataclasses.dataclass
@ -101,7 +122,7 @@ class RoutePattern:
strict_port: bool = True strict_port: bool = True
strict_trailing_slash: bool = False strict_trailing_slash: bool = False
def match(self, request: Request) -> typing.Optional[re.Match]: def match(self, request: Request) -> typing.Optional[re.Match[str]]:
""" """
Check if the given request URL matches this route pattern. Check if the given request URL matches this route pattern.
""" """
@ -112,12 +133,12 @@ class RoutePattern:
server_port = request.environ["SERVER_PORT"] server_port = request.environ["SERVER_PORT"]
if self.strict_hostname and request.hostname != server_hostname: if self.strict_hostname and request.hostname != server_hostname:
return return None
if self.strict_port and request.port is not None: if self.strict_port and request.port is not None:
if request.port != server_port: if request.port != server_port:
return return None
if self.scheme and self.scheme != request.scheme: if self.scheme and self.scheme != request.scheme:
return return None
if self.strict_trailing_slash: if self.strict_trailing_slash:
request_path = request.path request_path = request.path
@ -127,9 +148,6 @@ class RoutePattern:
return re.fullmatch(self.path, request_path) return re.fullmatch(self.path, request_path)
RouteHandler = typing.Callable[..., Response]
class RateLimiter: class RateLimiter:
""" """
A class that can be used to apply rate-limiting to endpoints. A class that can be used to apply rate-limiting to endpoints.
@ -144,6 +162,11 @@ class RateLimiter:
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])") RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
number: int
period: int
next_timestamp: float
rate_counter: typing.Dict[typing.Any, int]
def __init__(self, rate: str) -> None: def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate) match = self.RE.fullmatch(rate)
if not match: if not match:
@ -166,7 +189,7 @@ class RateLimiter:
self.next_timestamp = time.time() + self.period self.next_timestamp = time.time() + self.period
self.rate_counter = defaultdict(int) self.rate_counter = defaultdict(int)
def get_key(self, request: Request) -> typing.Optional[str]: def get_key(self, request: Request) -> typing.Any:
""" """
Rate limit based on the client's IP-address. Rate limit based on the client's IP-address.
""" """
@ -190,6 +213,8 @@ class RateLimiter:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds." msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg) return Response(Status.SLOW_DOWN, msg)
return None
def apply(self, wrapped_func: RouteHandler) -> RouteHandler: def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
""" """
Decorator to apply rate limiting to an individual application route. Decorator to apply rate limiting to an individual application route.
@ -203,7 +228,7 @@ class RateLimiter:
return Response(Status.SUCCESS, "text/gemini", "hello world!") return Response(Status.SUCCESS, "text/gemini", "hello world!")
""" """
def wrapper(request: Request, **kwargs) -> Response: def wrapper(request: Request, **kwargs: typing.Any) -> Response:
response = self.check(request) response = self.check(request)
if response: if response:
return response return response
@ -224,13 +249,16 @@ class JetforceApplication:
how to accomplish this. how to accomplish this.
""" """
rate_limiter: typing.Optional[RateLimiter]
routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None): def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
self.rate_limiter = rate_limiter self.rate_limiter = rate_limiter
self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = [] self.routes = []
def __call__( def __call__(
self, environ: dict, send_status: typing.Callable self, environ: EnvironDict, send_status: WriteStatusCallable
) -> typing.Iterator[ResponseType]: ) -> ApplicationResponse:
try: try:
request = Request(environ) request = Request(environ)
except Exception: except Exception:
@ -245,7 +273,7 @@ class JetforceApplication:
for route_pattern, callback in self.routes[::-1]: for route_pattern, callback in self.routes[::-1]:
match = route_pattern.match(request) match = route_pattern.match(request)
if route_pattern.match(request): if match:
callback_kwargs = match.groupdict() callback_kwargs = match.groupdict()
break break
else: else:
@ -267,7 +295,7 @@ class JetforceApplication:
hostname: typing.Optional[str] = None, hostname: typing.Optional[str] = None,
strict_hostname: bool = True, strict_hostname: bool = True,
strict_trailing_slash: bool = False, strict_trailing_slash: bool = False,
) -> typing.Callable: ) -> typing.Callable[[RouteHandler], RouteHandler]:
""" """
Decorator for binding a function to a route based on the URL path. Decorator for binding a function to a route based on the URL path.
@ -287,7 +315,7 @@ class JetforceApplication:
return wrap return wrap
def default_callback(self, request: Request, **_) -> Response: def default_callback(self, request: Request, **_: typing.Any) -> Response:
""" """
Set the error response based on the URL type. Set the error response based on the URL type.
""" """

View File

@ -1,6 +1,15 @@
import typing import typing
from .base import Request, ResponseType, Status from .base import (
ApplicationCallable,
ApplicationResponse,
EnvironDict,
Request,
Status,
WriteStatusCallable,
)
ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable]
class CompositeApplication: class CompositeApplication:
@ -11,7 +20,7 @@ class CompositeApplication:
two or more applications behind a single jetforce server. two or more applications behind a single jetforce server.
""" """
def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]): def __init__(self, application_map: ApplicationMap):
""" """
Initialize the application by providing a mapping of hostname -> app Initialize the application by providing a mapping of hostname -> app
key pairs. A hostname of `None` is a special key that can be used as key pairs. A hostname of `None` is a special key that can be used as
@ -29,8 +38,8 @@ class CompositeApplication:
self.application_map = application_map self.application_map = application_map
def __call__( def __call__(
self, environ: dict, send_status: typing.Callable self, environ: EnvironDict, send_status: WriteStatusCallable
) -> typing.Iterator[ResponseType]: ) -> ApplicationResponse:
try: try:
request = Request(environ) request = Request(environ)
except Exception: except Exception:

View File

@ -7,6 +7,7 @@ import typing
import urllib.parse import urllib.parse
from .base import ( from .base import (
EnvironDict,
JetforceApplication, JetforceApplication,
RateLimiter, RateLimiter,
Request, Request,
@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication):
# Chunk size for streaming files, taken from the twisted FileSender class # Chunk size for streaming files, taken from the twisted FileSender class
CHUNK_SIZE = 2 ** 14 CHUNK_SIZE = 2 ** 14
mimetypes: mimetypes.MimeTypes
def __init__( def __init__(
self, self,
root_directory: str = "/var/gemini", root_directory: str = "/var/gemini",
@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication):
if os.path.isfile(fn): if os.path.isfile(fn):
self.mimetypes.read(fn) self.mimetypes.read(fn)
self.mimetypes.add_type("text/gemini", ".gmi") # This is a valid method but the type stubs are incorrect
self.mimetypes.add_type("text/gemini", ".gemini") self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore
self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore
def serve_static_file(self, request: Request) -> Response: def serve_static_file(self, request: Request) -> Response:
""" """
@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication):
else: else:
return Response(Status.NOT_FOUND, "Not Found") return Response(Status.NOT_FOUND, "Not Found")
def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response: def run_cgi_script(
self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict
) -> Response:
""" """
Execute the given file as a CGI script and return the script's stdout Execute the given file as a CGI script and return the script's stdout
stream to the client. stream to the client.
@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication):
meta += f"; lang={self.default_lang}" meta += f"; lang={self.default_lang}"
return meta return meta
def default_callback(self, request: Request, **_) -> Response: def default_callback(self, request: Request, **_: typing.Any) -> Response:
""" """
Since the StaticDirectoryApplication only serves gemini URLs, return Since the StaticDirectoryApplication only serves gemini URLs, return
a proxy request refused for suspicious URLs. a proxy request refused for suspicious URLs.

View File

@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed
from twisted.internet.protocol import connectionDone from twisted.internet.protocol import connectionDone
from twisted.internet.task import deferLater from twisted.internet.task import deferLater
from twisted.protocols.basic import LineOnlyReceiver from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from .__version__ import __version__ from .__version__ import __version__
from .app.base import JetforceApplication, Status from .app.base import ApplicationCallable, EnvironDict, Status
from .tls import inspect_certificate from .tls import inspect_certificate
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver):
response_buffer: str response_buffer: str
response_size: int response_size: int
def __init__(self, server: GeminiServer, app: JetforceApplication): def __init__(self, server: GeminiServer, app: ApplicationCallable):
self.server = server self.server = server
self.app = app self.app = app
self._currently_deferred: typing.Optional[Deferred] = None self._currently_deferred: typing.Optional[Deferred] = None
def connectionMade(self): def connectionMade(self) -> None:
""" """
This is invoked by twisted after the connection is first established. This is invoked by twisted after the connection is first established.
""" """
@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver):
self.response_buffer = "" self.response_buffer = ""
self.client_addr = self.transport.getPeer() self.client_addr = self.transport.getPeer()
def connectionLost(self, reason=connectionDone): def connectionLost(self, reason: Failure = connectionDone) -> None:
""" """
This is invoked by twisted after the connection has been closed. This is invoked by twisted after the connection has been closed.
""" """
@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver):
self._currently_deferred.errback(reason) self._currently_deferred.errback(reason)
self._currently_deferred = None self._currently_deferred = None
def lineReceived(self, line): def lineReceived(self, line: bytes) -> Deferred:
""" """
This method is invoked by LineOnlyReceiver for every incoming line. This method is invoked by LineOnlyReceiver for every incoming line.
""" """
self.request = line self.request = line
return ensureDeferred(self._handle_request_noblock()) return ensureDeferred(self._handle_request_noblock())
def lineLengthExceeded(self, line): def lineLengthExceeded(self, line: bytes) -> None:
""" """
Called when the maximum line length has been reached. Called when the maximum line length has been reached.
""" """
return self.finish_connection() self.finish_connection()
def finish_connection(self): def finish_connection(self) -> None:
""" """
Send the TLS "close_notify" alert and then immediately close the TCP Send the TLS "close_notify" alert and then immediately close the TCP
connection without waiting for the client to respond with it's own connection without waiting for the client to respond with it's own
@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver):
# part of the above TLS shutdown. # part of the above TLS shutdown.
self.transport.transport.loseConnection() self.transport.transport.loseConnection()
async def _handle_request_noblock(self): async def _handle_request_noblock(self) -> None:
""" """
Handle the gemini request and write the raw response to the socket. Handle the gemini request and write the raw response to the socket.
@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver):
self.log_request() self.log_request()
self.finish_connection() self.finish_connection()
async def track_deferred(self, deferred: Deferred): async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]:
self._currently_deferred = deferred self._currently_deferred = deferred
try: try:
return await deferred return await deferred
finally: finally:
self._currently_deferred = None self._currently_deferred = None
def build_environ(self) -> typing.Dict[str, typing.Any]: def build_environ(self) -> EnvironDict:
""" """
Construct a dictionary that will be passed to the application handler. Construct a dictionary that will be passed to the application handler.

View File

@ -4,13 +4,14 @@ import socket
import sys import sys
import typing import typing
from twisted.internet import reactor from twisted.internet import reactor as _reactor
from twisted.internet.base import ReactorBase from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet.tcp import Port from twisted.internet.tcp import Port
from .__version__ import __version__ from .__version__ import __version__
from .app.base import ApplicationCallable
from .protocol import GeminiProtocol from .protocol import GeminiProtocol
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
@ -49,8 +50,8 @@ class GeminiServer(Factory):
def __init__( def __init__(
self, self,
app: typing.Callable, app: ApplicationCallable,
reactor: ReactorBase = reactor, reactor: ReactorBase = _reactor,
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 1965, port: int = 1965,
hostname: str = "localhost", hostname: str = "localhost",
@ -95,7 +96,7 @@ class GeminiServer(Factory):
else: else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}") self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
def buildProtocol(self, addr) -> GeminiProtocol: def buildProtocol(self, addr: typing.Any) -> GeminiProtocol:
""" """
This method is invoked by twisted once for every incoming connection. This method is invoked by twisted once for every incoming connection.

View File

@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom
COMMON_NAME = x509.NameOID.COMMON_NAME COMMON_NAME = x509.NameOID.COMMON_NAME
def inspect_certificate(cert: x509) -> dict: def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]:
""" """
Extract useful fields from a x509 client certificate object. Extract useful fields from a x509 client certificate object.
""" """
@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
subject_name = x509.Name([common_name]) subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow() not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365) not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder( cert_builder = x509.CertificateBuilder(
subject_name=subject_name, subject_name=subject_name,
issuer_name=subject_name, issuer_name=subject_name,
public_key=private_key.public_key(), public_key=private_key.public_key(),
@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
not_valid_before=not_valid_before, not_valid_before=not_valid_before,
not_valid_after=not_valid_after, not_valid_after=not_valid_after,
) )
certificate = certificate.sign(private_key, hashes.SHA256(), backend) certificate = cert_builder.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp: with open(certfile, "wb") as fp:
# noinspection PyTypeChecker # noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM) cert_data = certificate.public_bytes(serialization.Encoding.PEM)
@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions):
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
""" """
_acceptableProtocols: typing.List[bytes]
def verify_callback( def verify_callback(
self, self,
conn: OpenSSL.SSL.Connection, conn: OpenSSL.SSL.Connection,

View File

@ -6,6 +6,7 @@ import argparse
import socket import socket
import ssl import ssl
import sys import sys
import typing
import urllib.parse import urllib.parse
context = ssl.create_default_context() context = ssl.create_default_context()
@ -13,7 +14,12 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE context.verify_mode = ssl.CERT_NONE
def fetch(url, host=None, port=None, use_sni=False): def fetch(
url: str,
host: typing.Optional[str] = None,
port: typing.Optional[int] = None,
use_sni: bool = False,
) -> None:
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
if not parsed_url.scheme: if not parsed_url.scheme:
parsed_url = urllib.parse.urlparse(f"gemini://{url}") parsed_url = urllib.parse.urlparse(f"gemini://{url}")
@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False):
# ssock.unwrap() # ssock.unwrap()
def run_client(): def run_client() -> None:
# fmt: off # fmt: off
parser = argparse.ArgumentParser(description="A simple gemini client") parser = argparse.ArgumentParser(description="A simple gemini client")
parser.add_argument("url") parser.add_argument("url")
@ -59,6 +65,7 @@ def run_client():
context.set_alpn_protocols([args.tls_alpn_protocol]) context.set_alpn_protocols([args.tls_alpn_protocol])
if args.tls_keylog: if args.tls_keylog:
# This is a "private" variable that the stdlib exposes for debugging
context.keylog_filename = args.tls_keylog context.keylog_filename = args.tls_keylog
fetch(args.url, args.host, args.port, args.tls_enable_sni) fetch(args.url, args.host, args.port, args.tls_enable_sni)

View File

@ -3,7 +3,7 @@ import codecs
import setuptools import setuptools
def long_description(): def long_description() -> str:
with codecs.open("README.md", encoding="utf8") as f: with codecs.open("README.md", encoding="utf8") as f:
return f.read() return f.read()