diff --git a/CHANGELOG.md b/CHANGELOG.md
index 16276fb..d17e62a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,10 +6,13 @@
- Added support for international domain names using IDN encoding.
+#### New Features
+
+- Several fixes & improvements to python type hinting coverage.
+- Added a ``py.typed`` file to indicate project support for type hints.
+
#### Bug Fixes
-- Added py.typed file to indicate that the jetforce python library has support
- for type hints.
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
``True``/``False`` instead of ``1``/``0``.
diff --git a/README.md b/README.md
index 98d0e82..8ef4e4c 100644
--- a/README.md
+++ b/README.md
@@ -190,7 +190,7 @@ additional modification by the server.
- GATEWAY_INTERFACE
-
- CGI version (for compatability with RFC 3785).
+ CGI version (for compatibility with RFC 3785).
Example: "CGI/1.1"
@@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
- AUTH_TYPE
-
- Authentication type (for compatability with RFC 3785).
+ Authentication type (for compatibility with RFC 3785).
Example: "CERTIFICATE"
diff --git a/examples/cgi/cowsay.cgi b/examples/cgi/cowsay.cgi
index 07b6aaa..1a2cf61 100755
--- a/examples/cgi/cowsay.cgi
+++ b/examples/cgi/cowsay.cgi
@@ -1,5 +1,5 @@
#!/usr/local/bin/python3.7
-"""
+r"""
CGI script that requests user supplied text using the INPUT status, and
pipes it into the `cowsay` program.
diff --git a/examples/counter.py b/examples/counter.py
index 16ca794..786daf0 100644
--- a/examples/counter.py
+++ b/examples/counter.py
@@ -59,8 +59,8 @@ def deferred_counter():
eventually run in the main event loop.
"""
- def delayed_callback(x):
- return f"{x}\r\n"
+ def delayed_callback(var):
+ return f"{var}\r\n"
for x in range(10):
yield deferLater(reactor, 1, delayed_callback, x)
diff --git a/jetforce/__main__.py b/jetforce/__main__.py
index 1c1e775..6f03e14 100644
--- a/jetforce/__main__.py
+++ b/jetforce/__main__.py
@@ -106,7 +106,7 @@ group.add_argument(
)
-def main():
+def main() -> None:
args = parser.parse_args()
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
app = StaticDirectoryApplication(
diff --git a/jetforce/app/base.py b/jetforce/app/base.py
index f9651ec..b147bb4 100644
--- a/jetforce/app/base.py
+++ b/jetforce/app/base.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import dataclasses
import re
import time
@@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse
from twisted.internet.defer import Deferred
+EnvironDict = typing.Dict[str, object]
ResponseType = typing.Union[str, bytes, Deferred]
+ApplicationResponse = typing.Iterable[ResponseType]
+WriteStatusCallable = typing.Callable[[int, str], None]
+ApplicationCallable = typing.Callable[
+ [EnvironDict, WriteStatusCallable], ApplicationResponse
+]
class Status:
@@ -45,9 +53,19 @@ class Request:
Object that encapsulates information about a single gemini request.
"""
- def __init__(self, environ: dict):
+ environ: EnvironDict
+ url: str
+ scheme: str
+ hostname: str
+ port: typing.Optional[int]
+ path: str
+ params: str
+ query: str
+ fragment: str
+
+ def __init__(self, environ: EnvironDict):
self.environ = environ
- self.url = environ["GEMINI_URL"]
+ self.url = typing.cast(str, environ["GEMINI_URL"])
url_parts = urlparse(self.url)
if not url_parts.hostname:
@@ -84,7 +102,10 @@ class Response:
status: int
meta: str
- body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None
+ body: typing.Union[None, ResponseType, ApplicationResponse] = None
+
+
+RouteHandler = typing.Callable[..., Response]
@dataclasses.dataclass
@@ -101,7 +122,7 @@ class RoutePattern:
strict_port: bool = True
strict_trailing_slash: bool = False
- def match(self, request: Request) -> typing.Optional[re.Match]:
+ def match(self, request: Request) -> typing.Optional[re.Match[str]]:
"""
Check if the given request URL matches this route pattern.
"""
@@ -112,12 +133,12 @@ class RoutePattern:
server_port = request.environ["SERVER_PORT"]
if self.strict_hostname and request.hostname != server_hostname:
- return
+ return None
if self.strict_port and request.port is not None:
if request.port != server_port:
- return
+ return None
if self.scheme and self.scheme != request.scheme:
- return
+ return None
if self.strict_trailing_slash:
request_path = request.path
@@ -127,9 +148,6 @@ class RoutePattern:
return re.fullmatch(self.path, request_path)
-RouteHandler = typing.Callable[..., Response]
-
-
class RateLimiter:
"""
A class that can be used to apply rate-limiting to endpoints.
@@ -144,6 +162,11 @@ class RateLimiter:
RE = re.compile("(?P[0-9]+)/(?P[0-9]+)?(?P[smhd])")
+ number: int
+ period: int
+ next_timestamp: float
+ rate_counter: typing.Dict[typing.Any, int]
+
def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate)
if not match:
@@ -166,7 +189,7 @@ class RateLimiter:
self.next_timestamp = time.time() + self.period
self.rate_counter = defaultdict(int)
- def get_key(self, request: Request) -> typing.Optional[str]:
+ def get_key(self, request: Request) -> typing.Any:
"""
Rate limit based on the client's IP-address.
"""
@@ -190,6 +213,8 @@ class RateLimiter:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg)
+ return None
+
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
"""
Decorator to apply rate limiting to an individual application route.
@@ -203,7 +228,7 @@ class RateLimiter:
return Response(Status.SUCCESS, "text/gemini", "hello world!")
"""
- def wrapper(request: Request, **kwargs) -> Response:
+ def wrapper(request: Request, **kwargs: typing.Any) -> Response:
response = self.check(request)
if response:
return response
@@ -224,13 +249,16 @@ class JetforceApplication:
how to accomplish this.
"""
+ rate_limiter: typing.Optional[RateLimiter]
+ routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]
+
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
self.rate_limiter = rate_limiter
- self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = []
+ self.routes = []
def __call__(
- self, environ: dict, send_status: typing.Callable
- ) -> typing.Iterator[ResponseType]:
+ self, environ: EnvironDict, send_status: WriteStatusCallable
+ ) -> ApplicationResponse:
try:
request = Request(environ)
except Exception:
@@ -245,7 +273,7 @@ class JetforceApplication:
for route_pattern, callback in self.routes[::-1]:
match = route_pattern.match(request)
- if route_pattern.match(request):
+ if match:
callback_kwargs = match.groupdict()
break
else:
@@ -267,7 +295,7 @@ class JetforceApplication:
hostname: typing.Optional[str] = None,
strict_hostname: bool = True,
strict_trailing_slash: bool = False,
- ) -> typing.Callable:
+ ) -> typing.Callable[[RouteHandler], RouteHandler]:
"""
Decorator for binding a function to a route based on the URL path.
@@ -287,7 +315,7 @@ class JetforceApplication:
return wrap
- def default_callback(self, request: Request, **_) -> Response:
+ def default_callback(self, request: Request, **_: typing.Any) -> Response:
"""
Set the error response based on the URL type.
"""
diff --git a/jetforce/app/composite.py b/jetforce/app/composite.py
index 94940de..dccc7d5 100644
--- a/jetforce/app/composite.py
+++ b/jetforce/app/composite.py
@@ -1,6 +1,15 @@
import typing
-from .base import Request, ResponseType, Status
+from .base import (
+ ApplicationCallable,
+ ApplicationResponse,
+ EnvironDict,
+ Request,
+ Status,
+ WriteStatusCallable,
+)
+
+ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable]
class CompositeApplication:
@@ -11,7 +20,7 @@ class CompositeApplication:
two or more applications behind a single jetforce server.
"""
- def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]):
+ def __init__(self, application_map: ApplicationMap):
"""
Initialize the application by providing a mapping of hostname -> app
key pairs. A hostname of `None` is a special key that can be used as
@@ -29,8 +38,8 @@ class CompositeApplication:
self.application_map = application_map
def __call__(
- self, environ: dict, send_status: typing.Callable
- ) -> typing.Iterator[ResponseType]:
+ self, environ: EnvironDict, send_status: WriteStatusCallable
+ ) -> ApplicationResponse:
try:
request = Request(environ)
except Exception:
diff --git a/jetforce/app/static.py b/jetforce/app/static.py
index dbb6cbe..5ff3097 100644
--- a/jetforce/app/static.py
+++ b/jetforce/app/static.py
@@ -7,6 +7,7 @@ import typing
import urllib.parse
from .base import (
+ EnvironDict,
JetforceApplication,
RateLimiter,
Request,
@@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication):
# Chunk size for streaming files, taken from the twisted FileSender class
CHUNK_SIZE = 2 ** 14
+ mimetypes: mimetypes.MimeTypes
+
def __init__(
self,
root_directory: str = "/var/gemini",
@@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication):
if os.path.isfile(fn):
self.mimetypes.read(fn)
- self.mimetypes.add_type("text/gemini", ".gmi")
- self.mimetypes.add_type("text/gemini", ".gemini")
+ # This is a valid method but the type stubs are incorrect
+ self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore
+ self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore
def serve_static_file(self, request: Request) -> Response:
"""
@@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication):
else:
return Response(Status.NOT_FOUND, "Not Found")
- def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response:
+ def run_cgi_script(
+ self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict
+ ) -> Response:
"""
Execute the given file as a CGI script and return the script's stdout
stream to the client.
@@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication):
meta += f"; lang={self.default_lang}"
return meta
- def default_callback(self, request: Request, **_) -> Response:
+ def default_callback(self, request: Request, **_: typing.Any) -> Response:
"""
Since the StaticDirectoryApplication only serves gemini URLs, return
a proxy request refused for suspicious URLs.
diff --git a/jetforce/protocol.py b/jetforce/protocol.py
index 1a65265..99b36a9 100644
--- a/jetforce/protocol.py
+++ b/jetforce/protocol.py
@@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed
from twisted.internet.protocol import connectionDone
from twisted.internet.task import deferLater
from twisted.protocols.basic import LineOnlyReceiver
+from twisted.python.failure import Failure
from .__version__ import __version__
-from .app.base import JetforceApplication, Status
+from .app.base import ApplicationCallable, EnvironDict, Status
from .tls import inspect_certificate
if typing.TYPE_CHECKING:
@@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver):
response_buffer: str
response_size: int
- def __init__(self, server: GeminiServer, app: JetforceApplication):
+ def __init__(self, server: GeminiServer, app: ApplicationCallable):
self.server = server
self.app = app
self._currently_deferred: typing.Optional[Deferred] = None
- def connectionMade(self):
+ def connectionMade(self) -> None:
"""
This is invoked by twisted after the connection is first established.
"""
@@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver):
self.response_buffer = ""
self.client_addr = self.transport.getPeer()
- def connectionLost(self, reason=connectionDone):
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
"""
This is invoked by twisted after the connection has been closed.
"""
@@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver):
self._currently_deferred.errback(reason)
self._currently_deferred = None
- def lineReceived(self, line):
+ def lineReceived(self, line: bytes) -> Deferred:
"""
This method is invoked by LineOnlyReceiver for every incoming line.
"""
self.request = line
return ensureDeferred(self._handle_request_noblock())
- def lineLengthExceeded(self, line):
+ def lineLengthExceeded(self, line: bytes) -> None:
"""
Called when the maximum line length has been reached.
"""
- return self.finish_connection()
+ self.finish_connection()
- def finish_connection(self):
+ def finish_connection(self) -> None:
"""
Send the TLS "close_notify" alert and then immediately close the TCP
connection without waiting for the client to respond with it's own
@@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver):
# part of the above TLS shutdown.
self.transport.transport.loseConnection()
- async def _handle_request_noblock(self):
+ async def _handle_request_noblock(self) -> None:
"""
Handle the gemini request and write the raw response to the socket.
@@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver):
self.log_request()
self.finish_connection()
- async def track_deferred(self, deferred: Deferred):
+ async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]:
self._currently_deferred = deferred
try:
return await deferred
finally:
self._currently_deferred = None
- def build_environ(self) -> typing.Dict[str, typing.Any]:
+ def build_environ(self) -> EnvironDict:
"""
Construct a dictionary that will be passed to the application handler.
diff --git a/jetforce/server.py b/jetforce/server.py
index 092ac99..41916f3 100644
--- a/jetforce/server.py
+++ b/jetforce/server.py
@@ -4,13 +4,14 @@ import socket
import sys
import typing
-from twisted.internet import reactor
+from twisted.internet import reactor as _reactor
from twisted.internet.base import ReactorBase
from twisted.internet.endpoints import SSL4ServerEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.tcp import Port
from .__version__ import __version__
+from .app.base import ApplicationCallable
from .protocol import GeminiProtocol
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
@@ -49,8 +50,8 @@ class GeminiServer(Factory):
def __init__(
self,
- app: typing.Callable,
- reactor: ReactorBase = reactor,
+ app: ApplicationCallable,
+ reactor: ReactorBase = _reactor,
host: str = "127.0.0.1",
port: int = 1965,
hostname: str = "localhost",
@@ -95,7 +96,7 @@ class GeminiServer(Factory):
else:
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
- def buildProtocol(self, addr) -> GeminiProtocol:
+ def buildProtocol(self, addr: typing.Any) -> GeminiProtocol:
"""
This method is invoked by twisted once for every incoming connection.
diff --git a/jetforce/tls.py b/jetforce/tls.py
index 745b320..b8315a8 100644
--- a/jetforce/tls.py
+++ b/jetforce/tls.py
@@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom
COMMON_NAME = x509.NameOID.COMMON_NAME
-def inspect_certificate(cert: x509) -> dict:
+def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]:
"""
Extract useful fields from a x509 client certificate object.
"""
@@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
subject_name = x509.Name([common_name])
not_valid_before = datetime.datetime.utcnow()
not_valid_after = not_valid_before + datetime.timedelta(days=365)
- certificate = x509.CertificateBuilder(
+ cert_builder = x509.CertificateBuilder(
subject_name=subject_name,
issuer_name=subject_name,
public_key=private_key.public_key(),
@@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
)
- certificate = certificate.sign(private_key, hashes.SHA256(), backend)
+ certificate = cert_builder.sign(private_key, hashes.SHA256(), backend)
with open(certfile, "wb") as fp:
# noinspection PyTypeChecker
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
@@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions):
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
"""
+ _acceptableProtocols: typing.List[bytes]
+
def verify_callback(
self,
conn: OpenSSL.SSL.Connection,
diff --git a/jetforce_client.py b/jetforce_client.py
index be0654f..8e8ecf7 100755
--- a/jetforce_client.py
+++ b/jetforce_client.py
@@ -6,6 +6,7 @@ import argparse
import socket
import ssl
import sys
+import typing
import urllib.parse
context = ssl.create_default_context()
@@ -13,7 +14,12 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
-def fetch(url, host=None, port=None, use_sni=False):
+def fetch(
+ url: str,
+ host: typing.Optional[str] = None,
+ port: typing.Optional[int] = None,
+ use_sni: bool = False,
+) -> None:
parsed_url = urllib.parse.urlparse(url)
if not parsed_url.scheme:
parsed_url = urllib.parse.urlparse(f"gemini://{url}")
@@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False):
# ssock.unwrap()
-def run_client():
+def run_client() -> None:
# fmt: off
parser = argparse.ArgumentParser(description="A simple gemini client")
parser.add_argument("url")
@@ -59,6 +65,7 @@ def run_client():
context.set_alpn_protocols([args.tls_alpn_protocol])
if args.tls_keylog:
+ # This is a "private" variable that the stdlib exposes for debugging
context.keylog_filename = args.tls_keylog
fetch(args.url, args.host, args.port, args.tls_enable_sni)
diff --git a/setup.py b/setup.py
index 899956b..b9efcc7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ import codecs
import setuptools
-def long_description():
+def long_description() -> str:
with codecs.open("README.md", encoding="utf8") as f:
return f.read()