Type hints (#49)
* Better type hint coverage * linting fixes * Keep the old ResponseType name for backwards compatibility * Remove unnecessary assert statement Co-authored-by: Michael Lazar <mlazar@doctorondemand.com>
This commit is contained in:
parent
637025c8c3
commit
135dbda878
|
@ -6,10 +6,13 @@
|
||||||
|
|
||||||
- Added support for international domain names using IDN encoding.
|
- Added support for international domain names using IDN encoding.
|
||||||
|
|
||||||
|
#### New Features
|
||||||
|
|
||||||
|
- Several fixes & improvements to python type hinting coverage.
|
||||||
|
- Added a ``py.typed`` file to indicate project support for type hints.
|
||||||
|
|
||||||
#### Bug Fixes
|
#### Bug Fixes
|
||||||
|
|
||||||
- Added py.typed file to indicate that the jetforce python library has support
|
|
||||||
for type hints.
|
|
||||||
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
|
- Fixed a bug where TLS_CLIENT_AUTHORISED would sometimes be set to
|
||||||
``True``/``False`` instead of ``1``/``0``.
|
``True``/``False`` instead of ``1``/``0``.
|
||||||
|
|
||||||
|
|
|
@ -190,7 +190,7 @@ additional modification by the server.
|
||||||
<dl>
|
<dl>
|
||||||
<dt>GATEWAY_INTERFACE</dt>
|
<dt>GATEWAY_INTERFACE</dt>
|
||||||
<dd>
|
<dd>
|
||||||
CGI version (for compatability with RFC 3785).<br>
|
CGI version (for compatibility with RFC 3785).<br>
|
||||||
<em>Example: "CGI/1.1"</em>
|
<em>Example: "CGI/1.1"</em>
|
||||||
</dd>
|
</dd>
|
||||||
|
|
||||||
|
@ -270,7 +270,7 @@ Additional CGI variables will be included only when the client connection uses a
|
||||||
|
|
||||||
<dt>AUTH_TYPE</dt>
|
<dt>AUTH_TYPE</dt>
|
||||||
<dd>
|
<dd>
|
||||||
Authentication type (for compatability with RFC 3785).<br>
|
Authentication type (for compatibility with RFC 3785).<br>
|
||||||
<em>Example: "CERTIFICATE"</em>
|
<em>Example: "CERTIFICATE"</em>
|
||||||
</dd>
|
</dd>
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
#!/usr/local/bin/python3.7
|
#!/usr/local/bin/python3.7
|
||||||
"""
|
r"""
|
||||||
CGI script that requests user supplied text using the INPUT status, and
|
CGI script that requests user supplied text using the INPUT status, and
|
||||||
pipes it into the `cowsay` program.
|
pipes it into the `cowsay` program.
|
||||||
|
|
||||||
|
|
|
@ -59,8 +59,8 @@ def deferred_counter():
|
||||||
eventually run in the main event loop.
|
eventually run in the main event loop.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delayed_callback(x):
|
def delayed_callback(var):
|
||||||
return f"{x}\r\n"
|
return f"{var}\r\n"
|
||||||
|
|
||||||
for x in range(10):
|
for x in range(10):
|
||||||
yield deferLater(reactor, 1, delayed_callback, x)
|
yield deferLater(reactor, 1, delayed_callback, x)
|
||||||
|
|
|
@ -106,7 +106,7 @@ group.add_argument(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
|
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
|
||||||
app = StaticDirectoryApplication(
|
app = StaticDirectoryApplication(
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
@ -7,7 +9,13 @@ from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
|
EnvironDict = typing.Dict[str, object]
|
||||||
ResponseType = typing.Union[str, bytes, Deferred]
|
ResponseType = typing.Union[str, bytes, Deferred]
|
||||||
|
ApplicationResponse = typing.Iterable[ResponseType]
|
||||||
|
WriteStatusCallable = typing.Callable[[int, str], None]
|
||||||
|
ApplicationCallable = typing.Callable[
|
||||||
|
[EnvironDict, WriteStatusCallable], ApplicationResponse
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Status:
|
class Status:
|
||||||
|
@ -45,9 +53,19 @@ class Request:
|
||||||
Object that encapsulates information about a single gemini request.
|
Object that encapsulates information about a single gemini request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, environ: dict):
|
environ: EnvironDict
|
||||||
|
url: str
|
||||||
|
scheme: str
|
||||||
|
hostname: str
|
||||||
|
port: typing.Optional[int]
|
||||||
|
path: str
|
||||||
|
params: str
|
||||||
|
query: str
|
||||||
|
fragment: str
|
||||||
|
|
||||||
|
def __init__(self, environ: EnvironDict):
|
||||||
self.environ = environ
|
self.environ = environ
|
||||||
self.url = environ["GEMINI_URL"]
|
self.url = typing.cast(str, environ["GEMINI_URL"])
|
||||||
|
|
||||||
url_parts = urlparse(self.url)
|
url_parts = urlparse(self.url)
|
||||||
if not url_parts.hostname:
|
if not url_parts.hostname:
|
||||||
|
@ -84,7 +102,10 @@ class Response:
|
||||||
|
|
||||||
status: int
|
status: int
|
||||||
meta: str
|
meta: str
|
||||||
body: typing.Union[None, ResponseType, typing.Iterable[ResponseType]] = None
|
body: typing.Union[None, ResponseType, ApplicationResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
RouteHandler = typing.Callable[..., Response]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -101,7 +122,7 @@ class RoutePattern:
|
||||||
strict_port: bool = True
|
strict_port: bool = True
|
||||||
strict_trailing_slash: bool = False
|
strict_trailing_slash: bool = False
|
||||||
|
|
||||||
def match(self, request: Request) -> typing.Optional[re.Match]:
|
def match(self, request: Request) -> typing.Optional[re.Match[str]]:
|
||||||
"""
|
"""
|
||||||
Check if the given request URL matches this route pattern.
|
Check if the given request URL matches this route pattern.
|
||||||
"""
|
"""
|
||||||
|
@ -112,12 +133,12 @@ class RoutePattern:
|
||||||
server_port = request.environ["SERVER_PORT"]
|
server_port = request.environ["SERVER_PORT"]
|
||||||
|
|
||||||
if self.strict_hostname and request.hostname != server_hostname:
|
if self.strict_hostname and request.hostname != server_hostname:
|
||||||
return
|
return None
|
||||||
if self.strict_port and request.port is not None:
|
if self.strict_port and request.port is not None:
|
||||||
if request.port != server_port:
|
if request.port != server_port:
|
||||||
return
|
return None
|
||||||
if self.scheme and self.scheme != request.scheme:
|
if self.scheme and self.scheme != request.scheme:
|
||||||
return
|
return None
|
||||||
|
|
||||||
if self.strict_trailing_slash:
|
if self.strict_trailing_slash:
|
||||||
request_path = request.path
|
request_path = request.path
|
||||||
|
@ -127,9 +148,6 @@ class RoutePattern:
|
||||||
return re.fullmatch(self.path, request_path)
|
return re.fullmatch(self.path, request_path)
|
||||||
|
|
||||||
|
|
||||||
RouteHandler = typing.Callable[..., Response]
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
"""
|
"""
|
||||||
A class that can be used to apply rate-limiting to endpoints.
|
A class that can be used to apply rate-limiting to endpoints.
|
||||||
|
@ -144,6 +162,11 @@ class RateLimiter:
|
||||||
|
|
||||||
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
|
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
|
||||||
|
|
||||||
|
number: int
|
||||||
|
period: int
|
||||||
|
next_timestamp: float
|
||||||
|
rate_counter: typing.Dict[typing.Any, int]
|
||||||
|
|
||||||
def __init__(self, rate: str) -> None:
|
def __init__(self, rate: str) -> None:
|
||||||
match = self.RE.fullmatch(rate)
|
match = self.RE.fullmatch(rate)
|
||||||
if not match:
|
if not match:
|
||||||
|
@ -166,7 +189,7 @@ class RateLimiter:
|
||||||
self.next_timestamp = time.time() + self.period
|
self.next_timestamp = time.time() + self.period
|
||||||
self.rate_counter = defaultdict(int)
|
self.rate_counter = defaultdict(int)
|
||||||
|
|
||||||
def get_key(self, request: Request) -> typing.Optional[str]:
|
def get_key(self, request: Request) -> typing.Any:
|
||||||
"""
|
"""
|
||||||
Rate limit based on the client's IP-address.
|
Rate limit based on the client's IP-address.
|
||||||
"""
|
"""
|
||||||
|
@ -190,6 +213,8 @@ class RateLimiter:
|
||||||
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
|
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
|
||||||
return Response(Status.SLOW_DOWN, msg)
|
return Response(Status.SLOW_DOWN, msg)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
|
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
|
||||||
"""
|
"""
|
||||||
Decorator to apply rate limiting to an individual application route.
|
Decorator to apply rate limiting to an individual application route.
|
||||||
|
@ -203,7 +228,7 @@ class RateLimiter:
|
||||||
return Response(Status.SUCCESS, "text/gemini", "hello world!")
|
return Response(Status.SUCCESS, "text/gemini", "hello world!")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(request: Request, **kwargs) -> Response:
|
def wrapper(request: Request, **kwargs: typing.Any) -> Response:
|
||||||
response = self.check(request)
|
response = self.check(request)
|
||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
|
@ -224,13 +249,16 @@ class JetforceApplication:
|
||||||
how to accomplish this.
|
how to accomplish this.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
rate_limiter: typing.Optional[RateLimiter]
|
||||||
|
routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]
|
||||||
|
|
||||||
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
|
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
|
||||||
self.rate_limiter = rate_limiter
|
self.rate_limiter = rate_limiter
|
||||||
self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = []
|
self.routes = []
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, environ: dict, send_status: typing.Callable
|
self, environ: EnvironDict, send_status: WriteStatusCallable
|
||||||
) -> typing.Iterator[ResponseType]:
|
) -> ApplicationResponse:
|
||||||
try:
|
try:
|
||||||
request = Request(environ)
|
request = Request(environ)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -245,7 +273,7 @@ class JetforceApplication:
|
||||||
|
|
||||||
for route_pattern, callback in self.routes[::-1]:
|
for route_pattern, callback in self.routes[::-1]:
|
||||||
match = route_pattern.match(request)
|
match = route_pattern.match(request)
|
||||||
if route_pattern.match(request):
|
if match:
|
||||||
callback_kwargs = match.groupdict()
|
callback_kwargs = match.groupdict()
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -267,7 +295,7 @@ class JetforceApplication:
|
||||||
hostname: typing.Optional[str] = None,
|
hostname: typing.Optional[str] = None,
|
||||||
strict_hostname: bool = True,
|
strict_hostname: bool = True,
|
||||||
strict_trailing_slash: bool = False,
|
strict_trailing_slash: bool = False,
|
||||||
) -> typing.Callable:
|
) -> typing.Callable[[RouteHandler], RouteHandler]:
|
||||||
"""
|
"""
|
||||||
Decorator for binding a function to a route based on the URL path.
|
Decorator for binding a function to a route based on the URL path.
|
||||||
|
|
||||||
|
@ -287,7 +315,7 @@ class JetforceApplication:
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
def default_callback(self, request: Request, **_) -> Response:
|
def default_callback(self, request: Request, **_: typing.Any) -> Response:
|
||||||
"""
|
"""
|
||||||
Set the error response based on the URL type.
|
Set the error response based on the URL type.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,15 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .base import Request, ResponseType, Status
|
from .base import (
|
||||||
|
ApplicationCallable,
|
||||||
|
ApplicationResponse,
|
||||||
|
EnvironDict,
|
||||||
|
Request,
|
||||||
|
Status,
|
||||||
|
WriteStatusCallable,
|
||||||
|
)
|
||||||
|
|
||||||
|
ApplicationMap = typing.Dict[typing.Optional[str], ApplicationCallable]
|
||||||
|
|
||||||
|
|
||||||
class CompositeApplication:
|
class CompositeApplication:
|
||||||
|
@ -11,7 +20,7 @@ class CompositeApplication:
|
||||||
two or more applications behind a single jetforce server.
|
two or more applications behind a single jetforce server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, application_map: typing.Dict[typing.Optional[str], typing.Any]):
|
def __init__(self, application_map: ApplicationMap):
|
||||||
"""
|
"""
|
||||||
Initialize the application by providing a mapping of hostname -> app
|
Initialize the application by providing a mapping of hostname -> app
|
||||||
key pairs. A hostname of `None` is a special key that can be used as
|
key pairs. A hostname of `None` is a special key that can be used as
|
||||||
|
@ -29,8 +38,8 @@ class CompositeApplication:
|
||||||
self.application_map = application_map
|
self.application_map = application_map
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, environ: dict, send_status: typing.Callable
|
self, environ: EnvironDict, send_status: WriteStatusCallable
|
||||||
) -> typing.Iterator[ResponseType]:
|
) -> ApplicationResponse:
|
||||||
try:
|
try:
|
||||||
request = Request(environ)
|
request = Request(environ)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -7,6 +7,7 @@ import typing
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
|
EnvironDict,
|
||||||
JetforceApplication,
|
JetforceApplication,
|
||||||
RateLimiter,
|
RateLimiter,
|
||||||
Request,
|
Request,
|
||||||
|
@ -33,6 +34,8 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
# Chunk size for streaming files, taken from the twisted FileSender class
|
# Chunk size for streaming files, taken from the twisted FileSender class
|
||||||
CHUNK_SIZE = 2 ** 14
|
CHUNK_SIZE = 2 ** 14
|
||||||
|
|
||||||
|
mimetypes: mimetypes.MimeTypes
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
root_directory: str = "/var/gemini",
|
root_directory: str = "/var/gemini",
|
||||||
|
@ -57,8 +60,9 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
if os.path.isfile(fn):
|
if os.path.isfile(fn):
|
||||||
self.mimetypes.read(fn)
|
self.mimetypes.read(fn)
|
||||||
|
|
||||||
self.mimetypes.add_type("text/gemini", ".gmi")
|
# This is a valid method but the type stubs are incorrect
|
||||||
self.mimetypes.add_type("text/gemini", ".gemini")
|
self.mimetypes.add_type("text/gemini", ".gmi") # type: ignore
|
||||||
|
self.mimetypes.add_type("text/gemini", ".gemini") # type: ignore
|
||||||
|
|
||||||
def serve_static_file(self, request: Request) -> Response:
|
def serve_static_file(self, request: Request) -> Response:
|
||||||
"""
|
"""
|
||||||
|
@ -143,7 +147,9 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
else:
|
else:
|
||||||
return Response(Status.NOT_FOUND, "Not Found")
|
return Response(Status.NOT_FOUND, "Not Found")
|
||||||
|
|
||||||
def run_cgi_script(self, filesystem_path: pathlib.Path, environ: dict) -> Response:
|
def run_cgi_script(
|
||||||
|
self, filesystem_path: typing.Union[str, pathlib.Path], environ: EnvironDict
|
||||||
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Execute the given file as a CGI script and return the script's stdout
|
Execute the given file as a CGI script and return the script's stdout
|
||||||
stream to the client.
|
stream to the client.
|
||||||
|
@ -224,7 +230,7 @@ class StaticDirectoryApplication(JetforceApplication):
|
||||||
meta += f"; lang={self.default_lang}"
|
meta += f"; lang={self.default_lang}"
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
def default_callback(self, request: Request, **_) -> Response:
|
def default_callback(self, request: Request, **_: typing.Any) -> Response:
|
||||||
"""
|
"""
|
||||||
Since the StaticDirectoryApplication only serves gemini URLs, return
|
Since the StaticDirectoryApplication only serves gemini URLs, return
|
||||||
a proxy request refused for suspicious URLs.
|
a proxy request refused for suspicious URLs.
|
||||||
|
|
|
@ -11,9 +11,10 @@ from twisted.internet.error import ConnectionClosed
|
||||||
from twisted.internet.protocol import connectionDone
|
from twisted.internet.protocol import connectionDone
|
||||||
from twisted.internet.task import deferLater
|
from twisted.internet.task import deferLater
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
from twisted.protocols.basic import LineOnlyReceiver
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from .__version__ import __version__
|
from .__version__ import __version__
|
||||||
from .app.base import JetforceApplication, Status
|
from .app.base import ApplicationCallable, EnvironDict, Status
|
||||||
from .tls import inspect_certificate
|
from .tls import inspect_certificate
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
|
@ -49,12 +50,12 @@ class GeminiProtocol(LineOnlyReceiver):
|
||||||
response_buffer: str
|
response_buffer: str
|
||||||
response_size: int
|
response_size: int
|
||||||
|
|
||||||
def __init__(self, server: GeminiServer, app: JetforceApplication):
|
def __init__(self, server: GeminiServer, app: ApplicationCallable):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.app = app
|
self.app = app
|
||||||
self._currently_deferred: typing.Optional[Deferred] = None
|
self._currently_deferred: typing.Optional[Deferred] = None
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
"""
|
"""
|
||||||
This is invoked by twisted after the connection is first established.
|
This is invoked by twisted after the connection is first established.
|
||||||
"""
|
"""
|
||||||
|
@ -63,7 +64,7 @@ class GeminiProtocol(LineOnlyReceiver):
|
||||||
self.response_buffer = ""
|
self.response_buffer = ""
|
||||||
self.client_addr = self.transport.getPeer()
|
self.client_addr = self.transport.getPeer()
|
||||||
|
|
||||||
def connectionLost(self, reason=connectionDone):
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
"""
|
"""
|
||||||
This is invoked by twisted after the connection has been closed.
|
This is invoked by twisted after the connection has been closed.
|
||||||
"""
|
"""
|
||||||
|
@ -71,20 +72,20 @@ class GeminiProtocol(LineOnlyReceiver):
|
||||||
self._currently_deferred.errback(reason)
|
self._currently_deferred.errback(reason)
|
||||||
self._currently_deferred = None
|
self._currently_deferred = None
|
||||||
|
|
||||||
def lineReceived(self, line):
|
def lineReceived(self, line: bytes) -> Deferred:
|
||||||
"""
|
"""
|
||||||
This method is invoked by LineOnlyReceiver for every incoming line.
|
This method is invoked by LineOnlyReceiver for every incoming line.
|
||||||
"""
|
"""
|
||||||
self.request = line
|
self.request = line
|
||||||
return ensureDeferred(self._handle_request_noblock())
|
return ensureDeferred(self._handle_request_noblock())
|
||||||
|
|
||||||
def lineLengthExceeded(self, line):
|
def lineLengthExceeded(self, line: bytes) -> None:
|
||||||
"""
|
"""
|
||||||
Called when the maximum line length has been reached.
|
Called when the maximum line length has been reached.
|
||||||
"""
|
"""
|
||||||
return self.finish_connection()
|
self.finish_connection()
|
||||||
|
|
||||||
def finish_connection(self):
|
def finish_connection(self) -> None:
|
||||||
"""
|
"""
|
||||||
Send the TLS "close_notify" alert and then immediately close the TCP
|
Send the TLS "close_notify" alert and then immediately close the TCP
|
||||||
connection without waiting for the client to respond with it's own
|
connection without waiting for the client to respond with it's own
|
||||||
|
@ -111,7 +112,7 @@ class GeminiProtocol(LineOnlyReceiver):
|
||||||
# part of the above TLS shutdown.
|
# part of the above TLS shutdown.
|
||||||
self.transport.transport.loseConnection()
|
self.transport.transport.loseConnection()
|
||||||
|
|
||||||
async def _handle_request_noblock(self):
|
async def _handle_request_noblock(self) -> None:
|
||||||
"""
|
"""
|
||||||
Handle the gemini request and write the raw response to the socket.
|
Handle the gemini request and write the raw response to the socket.
|
||||||
|
|
||||||
|
@ -170,14 +171,14 @@ class GeminiProtocol(LineOnlyReceiver):
|
||||||
self.log_request()
|
self.log_request()
|
||||||
self.finish_connection()
|
self.finish_connection()
|
||||||
|
|
||||||
async def track_deferred(self, deferred: Deferred):
|
async def track_deferred(self, deferred: Deferred) -> typing.Union[str, bytes]:
|
||||||
self._currently_deferred = deferred
|
self._currently_deferred = deferred
|
||||||
try:
|
try:
|
||||||
return await deferred
|
return await deferred
|
||||||
finally:
|
finally:
|
||||||
self._currently_deferred = None
|
self._currently_deferred = None
|
||||||
|
|
||||||
def build_environ(self) -> typing.Dict[str, typing.Any]:
|
def build_environ(self) -> EnvironDict:
|
||||||
"""
|
"""
|
||||||
Construct a dictionary that will be passed to the application handler.
|
Construct a dictionary that will be passed to the application handler.
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,14 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor as _reactor
|
||||||
from twisted.internet.base import ReactorBase
|
from twisted.internet.base import ReactorBase
|
||||||
from twisted.internet.endpoints import SSL4ServerEndpoint
|
from twisted.internet.endpoints import SSL4ServerEndpoint
|
||||||
from twisted.internet.protocol import Factory
|
from twisted.internet.protocol import Factory
|
||||||
from twisted.internet.tcp import Port
|
from twisted.internet.tcp import Port
|
||||||
|
|
||||||
from .__version__ import __version__
|
from .__version__ import __version__
|
||||||
|
from .app.base import ApplicationCallable
|
||||||
from .protocol import GeminiProtocol
|
from .protocol import GeminiProtocol
|
||||||
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
|
from .tls import GeminiCertificateOptions, generate_ad_hoc_certificate
|
||||||
|
|
||||||
|
@ -49,8 +50,8 @@ class GeminiServer(Factory):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app: typing.Callable,
|
app: ApplicationCallable,
|
||||||
reactor: ReactorBase = reactor,
|
reactor: ReactorBase = _reactor,
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 1965,
|
port: int = 1965,
|
||||||
hostname: str = "localhost",
|
hostname: str = "localhost",
|
||||||
|
@ -95,7 +96,7 @@ class GeminiServer(Factory):
|
||||||
else:
|
else:
|
||||||
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
self.log_message(f"Listening on [{sock_ip}]:{sock_port}")
|
||||||
|
|
||||||
def buildProtocol(self, addr) -> GeminiProtocol:
|
def buildProtocol(self, addr: typing.Any) -> GeminiProtocol:
|
||||||
"""
|
"""
|
||||||
This method is invoked by twisted once for every incoming connection.
|
This method is invoked by twisted once for every incoming connection.
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from twisted.python.randbytes import secureRandom
|
||||||
COMMON_NAME = x509.NameOID.COMMON_NAME
|
COMMON_NAME = x509.NameOID.COMMON_NAME
|
||||||
|
|
||||||
|
|
||||||
def inspect_certificate(cert: x509) -> dict:
|
def inspect_certificate(cert: x509.Certificate) -> typing.Dict[str, object]:
|
||||||
"""
|
"""
|
||||||
Extract useful fields from a x509 client certificate object.
|
Extract useful fields from a x509 client certificate object.
|
||||||
"""
|
"""
|
||||||
|
@ -66,7 +66,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
||||||
subject_name = x509.Name([common_name])
|
subject_name = x509.Name([common_name])
|
||||||
not_valid_before = datetime.datetime.utcnow()
|
not_valid_before = datetime.datetime.utcnow()
|
||||||
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||||
certificate = x509.CertificateBuilder(
|
cert_builder = x509.CertificateBuilder(
|
||||||
subject_name=subject_name,
|
subject_name=subject_name,
|
||||||
issuer_name=subject_name,
|
issuer_name=subject_name,
|
||||||
public_key=private_key.public_key(),
|
public_key=private_key.public_key(),
|
||||||
|
@ -74,7 +74,7 @@ def generate_ad_hoc_certificate(hostname: str) -> typing.Tuple[str, str]:
|
||||||
not_valid_before=not_valid_before,
|
not_valid_before=not_valid_before,
|
||||||
not_valid_after=not_valid_after,
|
not_valid_after=not_valid_after,
|
||||||
)
|
)
|
||||||
certificate = certificate.sign(private_key, hashes.SHA256(), backend)
|
certificate = cert_builder.sign(private_key, hashes.SHA256(), backend)
|
||||||
with open(certfile, "wb") as fp:
|
with open(certfile, "wb") as fp:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
|
cert_data = certificate.public_bytes(serialization.Encoding.PEM)
|
||||||
|
@ -96,6 +96,8 @@ class GeminiCertificateOptions(CertificateOptions):
|
||||||
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
|
https://github.com/twisted/twisted/blob/trunk/src/twisted/internet/_sslverify.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_acceptableProtocols: typing.List[bytes]
|
||||||
|
|
||||||
def verify_callback(
|
def verify_callback(
|
||||||
self,
|
self,
|
||||||
conn: OpenSSL.SSL.Connection,
|
conn: OpenSSL.SSL.Connection,
|
||||||
|
|
|
@ -6,6 +6,7 @@ import argparse
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
context = ssl.create_default_context()
|
context = ssl.create_default_context()
|
||||||
|
@ -13,7 +14,12 @@ context.check_hostname = False
|
||||||
context.verify_mode = ssl.CERT_NONE
|
context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
|
||||||
def fetch(url, host=None, port=None, use_sni=False):
|
def fetch(
|
||||||
|
url: str,
|
||||||
|
host: typing.Optional[str] = None,
|
||||||
|
port: typing.Optional[int] = None,
|
||||||
|
use_sni: bool = False,
|
||||||
|
) -> None:
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
if not parsed_url.scheme:
|
if not parsed_url.scheme:
|
||||||
parsed_url = urllib.parse.urlparse(f"gemini://{url}")
|
parsed_url = urllib.parse.urlparse(f"gemini://{url}")
|
||||||
|
@ -38,7 +44,7 @@ def fetch(url, host=None, port=None, use_sni=False):
|
||||||
# ssock.unwrap()
|
# ssock.unwrap()
|
||||||
|
|
||||||
|
|
||||||
def run_client():
|
def run_client() -> None:
|
||||||
# fmt: off
|
# fmt: off
|
||||||
parser = argparse.ArgumentParser(description="A simple gemini client")
|
parser = argparse.ArgumentParser(description="A simple gemini client")
|
||||||
parser.add_argument("url")
|
parser.add_argument("url")
|
||||||
|
@ -59,6 +65,7 @@ def run_client():
|
||||||
context.set_alpn_protocols([args.tls_alpn_protocol])
|
context.set_alpn_protocols([args.tls_alpn_protocol])
|
||||||
|
|
||||||
if args.tls_keylog:
|
if args.tls_keylog:
|
||||||
|
# This is a "private" variable that the stdlib exposes for debugging
|
||||||
context.keylog_filename = args.tls_keylog
|
context.keylog_filename = args.tls_keylog
|
||||||
|
|
||||||
fetch(args.url, args.host, args.port, args.tls_enable_sni)
|
fetch(args.url, args.host, args.port, args.tls_enable_sni)
|
||||||
|
|
Loading…
Reference in New Issue