Type hints (#49)

* Better type hint coverage

* linting fixes

* Keep the old ResponseType name for backwards compatibility

* Remove unnecessary assert statement

Co-authored-by: Michael Lazar <mlazar@doctorondemand.com>
This commit is contained in:
Michael Lazar 2020-12-29 00:03:07 -05:00 committed by GitHub
parent 637025c8c3
commit 135dbda878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 112 additions and 55 deletions

View File

@ -6,10 +6,13 @@
- Added support for international domain names using IDN encoding.
#### New Features
- Several fixes & improvements to python type hinting coverage.
- Added a ``py.typed`` file to indicate project support for type hints.
#### Bug Fixes
- Added py.typed file to indicate that the jetforce python library has support
for type hints.
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
``True``/``False`` instead of ``1``/``0``.

View File

@ -190,7 +190,7 @@ additional modification by the server.
<dl>
<dt>GATEWAY_INTERFACE</dt>
<dd>
CGI version (for compatability with RFC 3785).<br>
CGI version (for compatibility with RFC 3785).<br>
<em>Example: "CGI/1.1"</em>
</dd>
@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
<dt>AUTH_TYPE</dt>
<dd>
Authentication type (for compatability with RFC 3785).<br>
Authentication type (for compatibility with RFC 3785).<br>
<em>Example: "CERTIFICATE"</em>
</dd>

View File

@ -1,5 +1,5 @@
#!/usr/local/bin/python3.7
"""
r"""
CGI script that requests user supplied text using the INPUT status, and
pipes it into the `cowsay` program.

View File

@ -59,8 +59,8 @@ def deferred_counter():
eventually run in the main event loop.
"""
def delayed_callback(x):
return f"{x}\r\n"
def delayed_callback(var):
return f"{var}\r\n"
for x in range(10):
yield deferLater(reactor, 1, delayed_callback, x)

View File

@ -106,7 +106,7 @@ group.add_argument(
)
def main():
def main() -> None:
args = parser.parse_args()
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
app = StaticDirectoryApplication(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses
import re
import time
@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse
from twisted.internet.defer import Deferred
EnvironDict = typing.Dict[str, object]
ResponseType = typing.Union[str, bytes, Deferred]
ApplicationResponse = typing.Iterable[ResponseType]
WriteStatusCallable = typing.Callable[[int, str], None]
ApplicationCallable = typing.Callable[
[EnvironDict, WriteStatusCallable], ApplicationResponse
]
class Status:
@ -45,9 +53,19 @@ class Request:
Object that encapsulates information about a single gemini request.
"""
def __init__(self, environ: dict):
environ: EnvironDict
url: str
scheme: str
hostname: str
port: typing.Optional[int]
path: str
params: str
query: str
fragment: str
def __init__(self, environ: EnvironDict):
self.environ = environ
self.url = environ["GEMINI_URL"]
self.url = typing.cast(str, environ["GEMINI_URL"])
url_parts = urlparse(self.url)
if not url_parts.hostname:
@ -84,7 +102,10 @@ class Response:
status: int
meta: str
body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None
body: typing.Union[None, ResponseType, ApplicationResponse] = None
RouteHandler = typing.Callable[..., Response]
@dataclasses.dataclass
@ -101,7 +122,7 @@ class RoutePattern:
strict_port: bool = True
strict_trailing_slash: bool = False
def match(self, request: Request) -> typing.Optional[re.Match]:
def match(self, request: Request) -> typing.Optional[re.Match[str]]:
"""
Check if the given request URL matches this route pattern.
"""
@ -112,12 +133,12 @@ class RoutePattern:
server_port = request.environ["SERVER_PORT"]
if self.strict_hostname and request.hostname != server_hostname:
return
return None
if self.strict_port and request.port is not None:
if request.port != server_port:
return
return None
if self.scheme and self.scheme != request.scheme:
return
return None
if self.strict_trailing_slash:
request_path = request.path
@ -127,9 +148,6 @@ class RoutePattern:
return re.fullmatch(self.path, request_path)
RouteHandler = typing.Callable[..., Response]
class RateLimiter:
"""
A class that can be used to apply rate-limiting to endpoints.
@ -144,6 +162,11 @@ class RateLimiter:
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
number: int
period: int
next_timestamp: float
rate_counter: typing.Dict[typing.Any, int]
def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate)
if not match:
@ -166,7 +189,7 @@ class RateLimiter:
self.next_timestamp = time.time() + self.period
self.rate_counter = defaultdict(int)
def get_key(self, request: Request) -> typing.Optional[str]:
def get_key(self, request: Request) -> typing.Any:
"""
Rate limit based on the client's IP-address.
"""
@ -190,6 +213,8 @@ class RateLimiter:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg)
return None
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
"""
Decorator to apply rate limiting to an individual application route.
@ -203,7 +228,7 @@ class RateLimiter:
return Response(Status.SUCCESS, "text/gemini", "hello world!")
"""
def wrapper(request: Request, **kwargs) -> Response:
def wrapper(request: Request, **kwargs: typing.Any) -> Response:
response = self.check(request)
if response:
return response
@ -224,13 +249,16 @@ class JetforceApplication:
how to accomplish this.
"""
rate_limiter: typing.Optional[RateLimiter]
routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
self.rate_limiter = rate_limiter
self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = []
self.routes = []
def __call__(
self, environ: dict, send_status: typing.Callable
) -> typing.Iterator[ResponseType]:
self, environ: EnvironDict, send_status: WriteStatusCallable
) -> ApplicationResponse:
try:
request = Request(environ)
except Exception:
@ -245,7 +273,7 @@ class JetforceApplication:
for route_pattern, callback in self.routes[::-1]:
match = route_pattern.match(request)
if route_pattern.match(request):
if match:
callback_kwargs = match.groupdict()
break
else:
@ -267,7 +295,7 @@ class JetforceApplication:
hostname: typing.Optional[str] = None,
strict_hostname: bool = True,
strict_trailing_slash: bool = False,
) -> typing.Callable:
) -> typing.Callable[[RouteHandler], RouteHandler]:
"""
Decorator for binding a function to a route based on the URL path.
@ -287,7 +315,7 @@ class JetforceApplication:
return wrap
def default_callback(self, request: Request, **_) -> Response:
def default_callback(self, request: Request, **_: typing.Any) -> Response:
"""
Set the error response based on the URL type.
"""

View File

@ -1,6 +1,15 @@
import typing
from .base import Request, ResponseType, Status
from .base import (
ApplicationCallable,
ApplicationResponse,
EnvironDict,
Request,
Status,
WriteStatusCallable,
)
ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable]
class CompositeApplication:
@ -11,7 +20,7 @@ class CompositeApplication:
two or more applications behind a single jetforce server.
"""
def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]):
def __init__(self, application_map: ApplicationMap):
"""
Initialize the application by providing a mapping of hostname -> app
key pairs. A hostname of `None` is a special key that can be used as
@ -29,8 +38,8 @@ class CompositeApplication:
self.application_map = application_map
def __call__(
self, environ: dict, send_status: typing.Callable
) -> typing.Iterator[ResponseType]:
self, environ: EnvironDict, send_status: WriteStatusCallable
) -> ApplicationResponse:
try:
request = Request(environ)
except Exception:

View File

@ -7,6 +7,7 @@ import typing
import urllib.parse
from .base import (
EnvironDict,
JetforceApplication,
RateLimiter,
Request,
@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication):
# Chunk size for streaming files, taken from the twisted FileSender class
CHUNK_SIZE = 2 ** 14
mimetypes: mimetypes.MimeTypes
def __init__(
self,
root_directory: str = "/var/gemini",
@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication):
if os.path.isfile(fn):
self.mimetypes.read(fn)
self.mimetypes.add_type("text/gemini", ".gmi")
self.mimetypes.add_type("text/gemini", ".gemini")
# This is a valid method but the type stubs are incorrect
self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore
self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore
def serve_static_file(self, request: Request) -> Response:
"""
@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication):
else:
return Response(Status.NOT_FOUND, "Not Found")
def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response:
def run_cgi_script(
self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict
) -> Response:
"""
Execute the given file as a CGI script and return the script's stdout
stream to the client.
@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication):
meta += f"; lang={self.default_lang}"
return meta
def default_callback(self, request: Request, **_) -> Response:
def default_callback(self, request: Request, **_: typing.Any) -> Response:
"""
Since the StaticDirectoryApplication only serves gemini URLs, return
a proxy request refused for suspicious URLs.

View File

@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed
from twisted.internet.protocol import connectionDone
from twisted.internet.task import deferLater
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from .__version__ import __version__
from .app.base import JetforceApplication, Status
from .app.base import ApplicationCallable, EnvironDict, Status
from .tls import inspect_certificate
if typing.TYPE_CHECKING:
@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver):
response_buffer: str
response_size: int
def __init__(self, server: GeminiServer, app: JetforceApplication):
def __init__(self, server: GeminiServer, app: ApplicationCallable):
self.server = server
self.app = app
self._currently_deferred: typing.Optional[Deferred] = None
def connectionMade(self):
def connectionMade(self) -> None:
"""
This is invoked by twisted after the connection is first established.
"""
@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver):
self.response_buffer = ""
self.client_addr = self.transport.getPeer()
def connectionLost(self, reason=connectionDone):
def connectionLost(self, reason: Failure = connectionDone) -> None:
"""
This is invoked by twisted after the connection has been closed.
"""
@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver):
self._currently_deferred.errback(reason)
self._currently_deferred = None
def lineReceived(self, line):
def lineReceived(self, line: bytes) -> Deferred:
"""
This method is invoked by LineOnlyReceiver for every incoming line.
"""
self.request = line
return ensureDeferred(self._handle_request_noblock())
def lineLengthExceeded(self, line):
def lineLengthExceeded(self, line: bytes) -> None:
"""
Called when the maximum line length has been reached.
"""
return self.finish_connection()
self.finish_connection()
def finish_connection(self):
def finish_connection(self) -> None:
"""
Send the TLS "close_notify" alert and then immediately close the TCP
connection without waiting for the client to respond with it's own
@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver):
# part of the above TLS shutdown.
self.transport.transport.loseConnection()
async def _handle_request_noblock(self):
async def _handle_request_noblock(self) -> None:
"""
Handle the gemini request and write the raw response to the socket.
@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver):
self.log_request()
self.finish_connection()
async def track_deferred(self, deferred: Deferred):
async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]:
self._currently_deferred = deferred
try:
return await deferred
finally:
self._currently_deferred = None
def build_environ(self) -> typing.Dict[str, typing.Any]:
def build_environ(self) -> EnvironDict:
"""
Construct a dictionary that will be passed to the application handler.

View File

@ -4,13 +4,14 @@ import socket
import sys
import typing
from twisted.internet import reactor
from twisted.internet import reactor as _reactor
from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.tcp import Port
from .__version__ import __version__
from .app.base import ApplicationCallable
from .protocol import GeminiProtocol
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
@ -49,8 +50,8 @@ class GeminiServer(Factory):
def __init__(
self,
app: typing.Callable,
reactor: ReactorBase = reactor,
app: ApplicationCallable,
reactor: ReactorBase = _reactor,
host: str = "127.0.0.1",
port: int = 1965,
hostname: str = "localhost",
@ -95,7 +96,7 @@ class GeminiServer(Factory):
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
def buildProtocol(self, addr) -> GeminiProtocol:
def buildProtocol(self, addr: typing.Any) -> GeminiProtocol:
"""
This method is invoked by twisted once for every incoming connection.

View File

@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom
COMMON_NAME = x509.NameOID.COMMON_NAME
def inspect_certificate(cert: x509) -> dict:
def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]:
"""
Extract useful fields from a x509 client certificate object.
"""
@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
certificate = x509.CertificateBuilder(
cert_builder = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
)
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
certificate = cert_builder.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions):
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
"""
_acceptableProtocols: typing.List[bytes]
def verify_callback(
self,
conn: OpenSSL.SSL.Connection,

View File

@ -6,6 +6,7 @@ import argparse
import socket
import ssl
import sys
import typing
import urllib.parse
context = ssl.create_default_context()
@ -13,7 +14,12 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
def fetch(url, host=None, port=None, use_sni=False):
def fetch(
url: str,
host: typing.Optional[str] = None,
port: typing.Optional[int] = None,
use_sni: bool = False,
) -> None:
parsed_url = urllib.parse.urlparse(url)
if not parsed_url.scheme:
parsed_url = urllib.parse.urlparse(f"gemini://{url}")
@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False):
# ssock.unwrap()
def run_client():
def run_client() -> None:
# fmt: off
parser = argparse.ArgumentParser(description="A simple gemini client")
parser.add_argument("url")
@ -59,6 +65,7 @@ def run_client():
context.set_alpn_protocols([args.tls_alpn_protocol])
if args.tls_keylog:
# This is a "private" variable that the stdlib exposes for debugging
context.keylog_filename = args.tls_keylog
fetch(args.url, args.host, args.port, args.tls_enable_sni)

View File

@ -3,7 +3,7 @@ import codecs
import setuptools
def long_description():
def long_description() -> str:
with codecs.open("README.md", encoding="utf8") as f:
return f.read()