Type hints (#49)
* Better type hint coverage * linting fixes * Keep the old ResponseType name for backwards compatibility * Remove unnecessary assert statement Co-authored-by: Michael Lazar <mlazar@doctorondemand.com>
This commit is contained in:
parent
637025c8c3
commit
135dbda878
|
@ -6,10 +6,13 @@
|
|||
|
||||
- Added support for international domain names using IDN encoding.
|
||||
|
||||
#### New Features
|
||||
|
||||
- Several fixes & improvements to python type hinting coverage.
|
||||
- Added a ``py.typed`` file to indicate project support for type hints.
|
||||
|
||||
#### Bug Fixes
|
||||
|
||||
- Added py.typed file to indicate that the jetforce python library has support
|
||||
for type hints.
|
||||
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
|
||||
``True``/``False`` instead of ``1``/``0``.
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ additional modification by the server.
|
|||
<dl>
|
||||
<dt>GATEWAY_INTERFACE</dt>
|
||||
<dd>
|
||||
CGI version (for compatability with RFC 3785).<br>
|
||||
CGI version (for compatibility with RFC 3785).<br>
|
||||
<em>Example: "CGI/1.1"</em>
|
||||
</dd>
|
||||
|
||||
|
@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
|
|||
|
||||
<dt>AUTH_TYPE</dt>
|
||||
<dd>
|
||||
Authentication type (for compatability with RFC 3785).<br>
|
||||
Authentication type (for compatibility with RFC 3785).<br>
|
||||
<em>Example: "CERTIFICATE"</em>
|
||||
</dd>
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#!/usr/local/bin/python3.7
|
||||
"""
|
||||
r"""
|
||||
CGI script that requests user supplied text using the INPUT status, and
|
||||
pipes it into the `cowsay` program.
|
||||
|
||||
|
|
|
@ -59,8 +59,8 @@ def deferred_counter():
|
|||
eventually run in the main event loop.
|
||||
"""
|
||||
|
||||
def delayed_callback(x):
|
||||
return f"{x}\r\n"
|
||||
def delayed_callback(var):
|
||||
return f"{var}\r\n"
|
||||
|
||||
for x in range(10):
|
||||
yield deferLater(reactor, 1, delayed_callback, x)
|
||||
|
|
|
@ -106,7 +106,7 @@ group.add_argument(
|
|||
)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
args = parser.parse_args()
|
||||
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
|
||||
app = StaticDirectoryApplication(
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
import time
|
||||
|
@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse
|
|||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
EnvironDict = typing.Dict[str, object]
|
||||
ResponseType = typing.Union[str, bytes, Deferred]
|
||||
ApplicationResponse = typing.Iterable[ResponseType]
|
||||
WriteStatusCallable = typing.Callable[[int, str], None]
|
||||
ApplicationCallable = typing.Callable[
|
||||
[EnvironDict, WriteStatusCallable], ApplicationResponse
|
||||
]
|
||||
|
||||
|
||||
class Status:
|
||||
|
@ -45,9 +53,19 @@ class Request:
|
|||
Object that encapsulates information about a single gemini request.
|
||||
"""
|
||||
|
||||
def __init__(self, environ: dict):
|
||||
environ: EnvironDict
|
||||
url: str
|
||||
scheme: str
|
||||
hostname: str
|
||||
port: typing.Optional[int]
|
||||
path: str
|
||||
params: str
|
||||
query: str
|
||||
fragment: str
|
||||
|
||||
def __init__(self, environ: EnvironDict):
|
||||
self.environ = environ
|
||||
self.url = environ["GEMINI_URL"]
|
||||
self.url = typing.cast(str, environ["GEMINI_URL"])
|
||||
|
||||
url_parts = urlparse(self.url)
|
||||
if not url_parts.hostname:
|
||||
|
@ -84,7 +102,10 @@ class Response:
|
|||
|
||||
status: int
|
||||
meta: str
|
||||
body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None
|
||||
body: typing.Union[None, ResponseType, ApplicationResponse] = None
|
||||
|
||||
|
||||
RouteHandler = typing.Callable[..., Response]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -101,7 +122,7 @@ class RoutePattern:
|
|||
strict_port: bool = True
|
||||
strict_trailing_slash: bool = False
|
||||
|
||||
def match(self, request: Request) -> typing.Optional[re.Match]:
|
||||
def match(self, request: Request) -> typing.Optional[re.Match[str]]:
|
||||
"""
|
||||
Check if the given request URL matches this route pattern.
|
||||
"""
|
||||
|
@ -112,12 +133,12 @@ class RoutePattern:
|
|||
server_port = request.environ["SERVER_PORT"]
|
||||
|
||||
if self.strict_hostname and request.hostname != server_hostname:
|
||||
return
|
||||
return None
|
||||
if self.strict_port and request.port is not None:
|
||||
if request.port != server_port:
|
||||
return
|
||||
return None
|
||||
if self.scheme and self.scheme != request.scheme:
|
||||
return
|
||||
return None
|
||||
|
||||
if self.strict_trailing_slash:
|
||||
request_path = request.path
|
||||
|
@ -127,9 +148,6 @@ class RoutePattern:
|
|||
return re.fullmatch(self.path, request_path)
|
||||
|
||||
|
||||
RouteHandler = typing.Callable[..., Response]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
A class that can be used to apply rate-limiting to endpoints.
|
||||
|
@ -144,6 +162,11 @@ class RateLimiter:
|
|||
|
||||
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
|
||||
|
||||
number: int
|
||||
period: int
|
||||
next_timestamp: float
|
||||
rate_counter: typing.Dict[typing.Any, int]
|
||||
|
||||
def __init__(self, rate: str) -> None:
|
||||
match = self.RE.fullmatch(rate)
|
||||
if not match:
|
||||
|
@ -166,7 +189,7 @@ class RateLimiter:
|
|||
self.next_timestamp = time.time() + self.period
|
||||
self.rate_counter = defaultdict(int)
|
||||
|
||||
def get_key(self, request: Request) -> typing.Optional[str]:
|
||||
def get_key(self, request: Request) -> typing.Any:
|
||||
"""
|
||||
Rate limit based on the client's IP-address.
|
||||
"""
|
||||
|
@ -190,6 +213,8 @@ class RateLimiter:
|
|||
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
|
||||
return Response(Status.SLOW_DOWN, msg)
|
||||
|
||||
return None
|
||||
|
||||
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
|
||||
"""
|
||||
Decorator to apply rate limiting to an individual application route.
|
||||
|
@ -203,7 +228,7 @@ class RateLimiter:
|
|||
return Response(Status.SUCCESS, "text/gemini", "hello world!")
|
||||
"""
|
||||
|
||||
def wrapper(request: Request, **kwargs) -> Response:
|
||||
def wrapper(request: Request, **kwargs: typing.Any) -> Response:
|
||||
response = self.check(request)
|
||||
if response:
|
||||
return response
|
||||
|
@ -224,13 +249,16 @@ class JetforceApplication:
|
|||
how to accomplish this.
|
||||
"""
|
||||
|
||||
rate_limiter: typing.Optional[RateLimiter]
|
||||
routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]
|
||||
|
||||
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
|
||||
self.rate_limiter = rate_limiter
|
||||
self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = []
|
||||
self.routes = []
|
||||
|
||||
def __call__(
|
||||
self, environ: dict, send_status: typing.Callable
|
||||
) -> typing.Iterator[ResponseType]:
|
||||
self, environ: EnvironDict, send_status: WriteStatusCallable
|
||||
) -> ApplicationResponse:
|
||||
try:
|
||||
request = Request(environ)
|
||||
except Exception:
|
||||
|
@ -245,7 +273,7 @@ class JetforceApplication:
|
|||
|
||||
for route_pattern, callback in self.routes[::-1]:
|
||||
match = route_pattern.match(request)
|
||||
if route_pattern.match(request):
|
||||
if match:
|
||||
callback_kwargs = match.groupdict()
|
||||
break
|
||||
else:
|
||||
|
@ -267,7 +295,7 @@ class JetforceApplication:
|
|||
hostname: typing.Optional[str] = None,
|
||||
strict_hostname: bool = True,
|
||||
strict_trailing_slash: bool = False,
|
||||
) -> typing.Callable:
|
||||
) -> typing.Callable[[RouteHandler], RouteHandler]:
|
||||
"""
|
||||
Decorator for binding a function to a route based on the URL path.
|
||||
|
||||
|
@ -287,7 +315,7 @@ class JetforceApplication:
|
|||
|
||||
return wrap
|
||||
|
||||
def default_callback(self, request: Request, **_) -> Response:
|
||||
def default_callback(self, request: Request, **_: typing.Any) -> Response:
|
||||
"""
|
||||
Set the error response based on the URL type.
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,15 @@
|
|||
import typing
|
||||
|
||||
from .base import Request, ResponseType, Status
|
||||
from .base import (
|
||||
ApplicationCallable,
|
||||
ApplicationResponse,
|
||||
EnvironDict,
|
||||
Request,
|
||||
Status,
|
||||
WriteStatusCallable,
|
||||
)
|
||||
|
||||
ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable]
|
||||
|
||||
|
||||
class CompositeApplication:
|
||||
|
@ -11,7 +20,7 @@ class CompositeApplication:
|
|||
two or more applications behind a single jetforce server.
|
||||
"""
|
||||
|
||||
def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]):
|
||||
def __init__(self, application_map: ApplicationMap):
|
||||
"""
|
||||
Initialize the application by providing a mapping of hostname -> app
|
||||
key pairs. A hostname of `None` is a special key that can be used as
|
||||
|
@ -29,8 +38,8 @@ class CompositeApplication:
|
|||
self.application_map = application_map
|
||||
|
||||
def __call__(
|
||||
self, environ: dict, send_status: typing.Callable
|
||||
) -> typing.Iterator[ResponseType]:
|
||||
self, environ: EnvironDict, send_status: WriteStatusCallable
|
||||
) -> ApplicationResponse:
|
||||
try:
|
||||
request = Request(environ)
|
||||
except Exception:
|
||||
|
|
|
@ -7,6 +7,7 @@ import typing
|
|||
import urllib.parse
|
||||
|
||||
from .base import (
|
||||
EnvironDict,
|
||||
JetforceApplication,
|
||||
RateLimiter,
|
||||
Request,
|
||||
|
@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
# Chunk size for streaming files, taken from the twisted FileSender class
|
||||
CHUNK_SIZE = 2 ** 14
|
||||
|
||||
mimetypes: mimetypes.MimeTypes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_directory: str = "/var/gemini",
|
||||
|
@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
if os.path.isfile(fn):
|
||||
self.mimetypes.read(fn)
|
||||
|
||||
self.mimetypes.add_type("text/gemini", ".gmi")
|
||||
self.mimetypes.add_type("text/gemini", ".gemini")
|
||||
# This is a valid method but the type stubs are incorrect
|
||||
self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore
|
||||
self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore
|
||||
|
||||
def serve_static_file(self, request: Request) -> Response:
|
||||
"""
|
||||
|
@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
else:
|
||||
return Response(Status.NOT_FOUND, "Not Found")
|
||||
|
||||
def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response:
|
||||
def run_cgi_script(
|
||||
self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict
|
||||
) -> Response:
|
||||
"""
|
||||
Execute the given file as a CGI script and return the script's stdout
|
||||
stream to the client.
|
||||
|
@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
|||
meta += f"; lang={self.default_lang}"
|
||||
return meta
|
||||
|
||||
def default_callback(self, request: Request, **_) -> Response:
|
||||
def default_callback(self, request: Request, **_: typing.Any) -> Response:
|
||||
"""
|
||||
Since the StaticDirectoryApplication only serves gemini URLs, return
|
||||
a proxy request refused for suspicious URLs.
|
||||
|
|
|
@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed
|
|||
from twisted.internet.protocol import connectionDone
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from .__version__ import __version__
|
||||
from .app.base import JetforceApplication, Status
|
||||
from .app.base import ApplicationCallable, EnvironDict, Status
|
||||
from .tls import inspect_certificate
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver):
|
|||
response_buffer: str
|
||||
response_size: int
|
||||
|
||||
def __init__(self, server: GeminiServer, app: JetforceApplication):
|
||||
def __init__(self, server: GeminiServer, app: ApplicationCallable):
|
||||
self.server = server
|
||||
self.app = app
|
||||
self._currently_deferred: typing.Optional[Deferred] = None
|
||||
|
||||
def connectionMade(self):
|
||||
def connectionMade(self) -> None:
|
||||
"""
|
||||
This is invoked by twisted after the connection is first established.
|
||||
"""
|
||||
|
@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver):
|
|||
self.response_buffer = ""
|
||||
self.client_addr = self.transport.getPeer()
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
"""
|
||||
This is invoked by twisted after the connection has been closed.
|
||||
"""
|
||||
|
@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver):
|
|||
self._currently_deferred.errback(reason)
|
||||
self._currently_deferred = None
|
||||
|
||||
def lineReceived(self, line):
|
||||
def lineReceived(self, line: bytes) -> Deferred:
|
||||
"""
|
||||
This method is invoked by LineOnlyReceiver for every incoming line.
|
||||
"""
|
||||
self.request = line
|
||||
return ensureDeferred(self._handle_request_noblock())
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
def lineLengthExceeded(self, line: bytes) -> None:
|
||||
"""
|
||||
Called when the maximum line length has been reached.
|
||||
"""
|
||||
return self.finish_connection()
|
||||
self.finish_connection()
|
||||
|
||||
def finish_connection(self):
|
||||
def finish_connection(self) -> None:
|
||||
"""
|
||||
Send the TLS "close_notify" alert and then immediately close the TCP
|
||||
connection without waiting for the client to respond with it's own
|
||||
|
@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver):
|
|||
# part of the above TLS shutdown.
|
||||
self.transport.transport.loseConnection()
|
||||
|
||||
async def _handle_request_noblock(self):
|
||||
async def _handle_request_noblock(self) -> None:
|
||||
"""
|
||||
Handle the gemini request and write the raw response to the socket.
|
||||
|
||||
|
@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver):
|
|||
self.log_request()
|
||||
self.finish_connection()
|
||||
|
||||
async def track_deferred(self, deferred: Deferred):
|
||||
async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]:
|
||||
self._currently_deferred = deferred
|
||||
try:
|
||||
return await deferred
|
||||
finally:
|
||||
self._currently_deferred = None
|
||||
|
||||
def build_environ(self) -> typing.Dict[str, typing.Any]:
|
||||
def build_environ(self) -> EnvironDict:
|
||||
"""
|
||||
Construct a dictionary that will be passed to the application handler.
|
||||
|
||||
|
|
|
@ -4,13 +4,14 @@ import socket
|
|||
import sys
|
||||
import typing
|
||||
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet import reactor as _reactor
|
||||
from twisted.internet.base import ReactorBase
|
||||
from twisted.internet.endpoints import SSL4ServerEndpoint
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet.tcp import Port
|
||||
|
||||
from .__version__ import __version__
|
||||
from .app.base import ApplicationCallable
|
||||
from .protocol import GeminiProtocol
|
||||
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
|
||||
|
||||
|
@ -49,8 +50,8 @@ class GeminiServer(Factory):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
app: typing.Callable,
|
||||
reactor: ReactorBase = reactor,
|
||||
app: ApplicationCallable,
|
||||
reactor: ReactorBase = _reactor,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 1965,
|
||||
hostname: str = "localhost",
|
||||
|
@ -95,7 +96,7 @@ class GeminiServer(Factory):
|
|||
else:
|
||||
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
||||
|
||||
def buildProtocol(self, addr) -> GeminiProtocol:
|
||||
def buildProtocol(self, addr: typing.Any) -> GeminiProtocol:
|
||||
"""
|
||||
This method is invoked by twisted once for every incoming connection.
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom
|
|||
COMMON_NAME = x509.NameOID.COMMON_NAME
|
||||
|
||||
|
||||
def inspect_certificate(cert: x509) -> dict:
|
||||
def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]:
|
||||
"""
|
||||
Extract useful fields from a x509 client certificate object.
|
||||
"""
|
||||
|
@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
|||
subject_name = x509.Name([common_name])
|
||||
not_valid_before = datetime.datetime.utcnow()
|
||||
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||
certificate = x509.CertificateBuilder(
|
||||
cert_builder = x509.CertificateBuilder(
|
||||
subject_name=subject_name,
|
||||
issuer_name=subject_name,
|
||||
public_key=private_key.public_key(),
|
||||
|
@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
|||
not_valid_before=not_valid_before,
|
||||
not_valid_after=not_valid_after,
|
||||
)
|
||||
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
|
||||
certificate = cert_builder.sign(private_key, hashes.SHA256(), backend)
|
||||
with open(certfile, "wb") as fp:
|
||||
# noinspection PyTypeChecker
|
||||
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
|
||||
|
@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions):
|
|||
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
|
||||
"""
|
||||
|
||||
_acceptableProtocols: typing.List[bytes]
|
||||
|
||||
def verify_callback(
|
||||
self,
|
||||
conn: OpenSSL.SSL.Connection,
|
||||
|
|
|
@ -6,6 +6,7 @@ import argparse
|
|||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
import urllib.parse
|
||||
|
||||
context = ssl.create_default_context()
|
||||
|
@ -13,7 +14,12 @@ context.check_hostname = False
|
|||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
|
||||
def fetch(url, host=None, port=None, use_sni=False):
|
||||
def fetch(
|
||||
url: str,
|
||||
host: typing.Optional[str] = None,
|
||||
port: typing.Optional[int] = None,
|
||||
use_sni: bool = False,
|
||||
) -> None:
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
if not parsed_url.scheme:
|
||||
parsed_url = urllib.parse.urlparse(f"gemini://{url}")
|
||||
|
@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False):
|
|||
# ssock.unwrap()
|
||||
|
||||
|
||||
def run_client():
|
||||
def run_client() -> None:
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(description="A simple gemini client")
|
||||
parser.add_argument("url")
|
||||
|
@ -59,6 +65,7 @@ def run_client():
|
|||
context.set_alpn_protocols([args.tls_alpn_protocol])
|
||||
|
||||
if args.tls_keylog:
|
||||
# This is a "private" variable that the stdlib exposes for debugging
|
||||
context.keylog_filename = args.tls_keylog
|
||||
|
||||
fetch(args.url, args.host, args.port, args.tls_enable_sni)
|
||||
|
|
Loading…
Reference in New Issue