325 lines
9.3 KiB
Python
325 lines
9.3 KiB
Python
from __future__ import annotations

import dataclasses
import functools
import re
import time
import typing
from collections import defaultdict
from urllib.parse import unquote, urlparse

from twisted.internet.defer import Deferred
|
|
|
|
# Type aliases describing the interface between the server and an application.

# Per-request environment mapping (string keys, arbitrary values).
EnvironDict = typing.Dict[str, typing.Any]

# A single chunk of response body: text, raw bytes, or a twisted Deferred.
ResponseType = typing.Union[str, bytes, Deferred]

# What an application returns/yields: an iterable of body chunks.
ApplicationResponse = typing.Iterable[ResponseType]

# Callback used by the application to emit the response header; it is called
# with an integer status and a string (the response "meta").
WriteStatusCallable = typing.Callable[[int, str], None]

# Signature of an application: (environ, send_status) -> iterable of chunks.
ApplicationCallable = typing.Callable[
    [EnvironDict, WriteStatusCallable],
    ApplicationResponse,
]
|
|
|
|
|
|
class Status:
    """
    Gemini response status codes.

    The constants are plain integers, grouped below by their leading digit.
    """

    # 1x - input requested from the client
    INPUT = 10
    SENSITIVE_INPUT = 11

    # 2x - success
    SUCCESS = 20

    # 3x - redirects
    REDIRECT_TEMPORARY = 30
    REDIRECT_PERMANENT = 31

    # 4x - temporary failures
    TEMPORARY_FAILURE = 40
    SERVER_UNAVAILABLE = 41
    CGI_ERROR = 42
    PROXY_ERROR = 43
    SLOW_DOWN = 44

    # 5x - permanent failures
    PERMANENT_FAILURE = 50
    NOT_FOUND = 51
    GONE = 52
    PROXY_REQUEST_REFUSED = 53
    BAD_REQUEST = 59

    # 6x - client certificate problems
    CLIENT_CERTIFICATE_REQUIRED = 60
    CERTIFICATE_NOT_AUTHORISED = 61
    CERTIFICATE_NOT_VALID = 62
|
|
|
|
|
|
class Request:
    """
    Object that encapsulates information about a single gemini request.

    The request URL is read from ``environ["GEMINI_URL"]`` and split into its
    components, which are exposed as attributes. The path, params, query and
    fragment components are stored percent-decoded; the hostname is stored in
    its IDNA (punycode) ASCII form.

    Raises:
        ValueError: If the URL has no hostname, has no scheme, or is a
            ``gemini://`` URL that carries a userinfo component.
    """

    environ: EnvironDict
    url: str
    scheme: str
    hostname: str
    port: typing.Optional[int]
    path: str
    params: str
    query: str
    fragment: str

    def __init__(self, environ: EnvironDict):
        self.environ = environ
        self.url = typing.cast(str, environ["GEMINI_URL"])

        parts = urlparse(self.url)

        # Hostname and scheme are both mandatory (hostname is checked first).
        if not parts.hostname:
            raise ValueError("Missing hostname component")
        if not parts.scheme:
            raise ValueError("Missing scheme component")
        self.scheme = parts.scheme

        # gemini://username@host/... is forbidden by the specification
        if parts.username and self.scheme == "gemini":
            raise ValueError("Invalid userinfo component")

        # Convert domain names to punycode for compatibility with URLs that
        # contain encoded IDNs (follows RFC 3490).
        self.hostname = parts.hostname.encode("idna").decode("ascii")
        self.port = parts.port

        # The remaining components are stored percent-decoded.
        for component in ("path", "params", "query", "fragment"):
            setattr(self, component, unquote(getattr(parts, component)))
|
|
|
|
|
|
@dataclasses.dataclass
class Response:
    """
    Object that encapsulates information about a single gemini response.

    Instances are plain value objects built by route handlers; the
    application reads ``status`` and ``meta`` to emit the response header and
    then streams ``body`` (if any) back to the client.
    """

    # Numeric response status (see the ``Status`` constants).
    status: int
    # Text accompanying the status in the response header.
    meta: str
    # Optional payload: a single chunk, an iterable of chunks, or nothing.
    body: typing.Union[None, ResponseType, ApplicationResponse] = None
|
|
|
|
|
|
# A route handler is any callable that produces a Response. The application
# invokes it with the request plus keyword arguments taken from the named
# groups of the matched route pattern.
RouteHandler = typing.Callable[..., Response]
|
|
|
|
|
|
@dataclasses.dataclass
class RoutePattern:
    """
    A pattern for matching URLs with a single endpoint or route.
    """

    # Regular expression that must match the entire request path.
    path: str = ".*"
    # Required URL scheme; a falsy value disables the scheme check.
    scheme: str = "gemini"
    # Expected hostname; when None, ``environ["HOSTNAME"]`` is used instead.
    hostname: typing.Optional[str] = None

    # Reject requests whose hostname differs from the expected hostname.
    strict_hostname: bool = True
    # Reject requests whose explicit URL port differs from the server port.
    strict_port: bool = True
    # When False, trailing slashes are stripped from the request path before
    # it is matched against ``path``.
    strict_trailing_slash: bool = False

    def match(self, request: Request) -> typing.Optional[re.Match[str]]:
        """
        Check if the given request URL matches this route pattern.

        Returns the ``re.Match`` produced by matching ``path`` against the
        whole request path, or None if the hostname, port, scheme or path
        check fails.
        """
        if self.hostname is None:
            server_hostname = request.environ["HOSTNAME"]
        else:
            server_hostname = self.hostname

        if self.strict_hostname and request.hostname != server_hostname:
            return None

        if self.strict_port and request.port is not None:
            # request.port is an int, while the environ value may be supplied
            # as a string; coerce it so equal ports always compare equal.
            server_port = int(request.environ["SERVER_PORT"])
            if request.port != server_port:
                return None

        if self.scheme and self.scheme != request.scheme:
            return None

        if self.strict_trailing_slash:
            request_path = request.path
        else:
            # Note that this reduces the root path "/" to the empty string.
            request_path = request.path.rstrip("/")

        return re.fullmatch(self.path, request_path)
|
|
|
|
|
|
class RateLimiter:
    """
    A class that can be used to apply rate-limiting to endpoints.

    Rates are defined as human-readable strings, e.g.

        "5/s" (5 requests per-second)
        "10/5m" (10 requests per-5 minutes)
        "100/2h" (100 requests per-2 hours)
        "1000/d" (1k requests per-day)

    Requests are counted per key (see ``get_key``) inside a fixed time
    window. All counters are cleared together when the window expires.
    """

    RE = re.compile(r"(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")

    # Number of seconds represented by each unit suffix.
    _UNIT_SECONDS = {"s": 1, "m": 60, "h": 60 * 60, "d": 60 * 60 * 24}

    # Maximum number of requests allowed per key within one window.
    number: int
    # Length of the window, in seconds.
    period: int
    # Time (as returned by time.time()) at which the current window ends.
    next_timestamp: float
    # Requests seen so far in the current window, by key.
    rate_counter: typing.Dict[typing.Any, int]

    def __init__(self, rate: str) -> None:
        """
        Parse the rate string and start the first window.

        Raises:
            ValueError: If ``rate`` is not of the form "<number>/[<period>]<unit>".
        """
        match = self.RE.fullmatch(rate)
        if not match:
            raise ValueError(f"Invalid rate format: {rate}")

        rate_data = match.groupdict()

        self.number = int(rate_data["number"])
        # The period multiplier is optional and defaults to a single unit.
        multiplier = int(rate_data["period"] or 1)
        self.period = multiplier * self._UNIT_SECONDS[rate_data["unit"]]

        self.reset()

    def reset(self) -> None:
        """
        Start a new window and discard every per-key counter.
        """
        self.next_timestamp = time.time() + self.period
        self.rate_counter = defaultdict(int)

    def get_key(self, request: Request) -> typing.Any:
        """
        Rate limit based on the client's IP-address.
        """
        return request.environ["REMOTE_ADDR"]

    def check(self, request: Request) -> typing.Optional[Response]:
        """
        Check if the given request should be rate limited.

        This method will return a failure response if the request should be
        rate limited. Otherwise (including when ``get_key`` returns None and
        the request is therefore not counted) it returns None.
        """
        time_left = self.next_timestamp - time.time()
        if time_left < 0:
            self.reset()
            # Refresh the remaining time so that the message below never
            # reports the stale, negative value from the expired window.
            time_left = self.next_timestamp - time.time()

        key = self.get_key(request)
        if key is not None:
            self.rate_counter[key] += 1
            if self.rate_counter[key] > self.number:
                msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
                return Response(Status.SLOW_DOWN, msg)

        return None

    def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
        """
        Decorator to apply rate limiting to an individual application route.

        Usage:
            rate_limiter = RateLimiter("10/m")

            @app.route("/endpoint")
            @rate_limiter.apply
            def my_endpoint(request):
                return Response(Status.SUCCESS, "text/gemini", "hello world!")
        """

        # Preserve the handler's name and docstring on the returned wrapper.
        @functools.wraps(wrapped_func)
        def wrapper(request: Request, **kwargs: typing.Any) -> Response:
            response = self.check(request)
            if response:
                return response
            return wrapped_func(request, **kwargs)

        return wrapper
|
|
|
|
|
|
class JetforceApplication:
    """
    Base Jetforce application class with primitive URL routing.

    This is a base class for writing jetforce server applications. It doesn't do
    anything on its own, but it does provide a convenient interface to define
    custom server endpoints using route decorators. If you want to utilize
    jetforce as a library and write your own server in python, this is the class
    that you want to extend. The examples/ directory contains some examples of
    how to accomplish this.
    """

    # Optional limiter applied to every request before routing.
    rate_limiter: typing.Optional[RateLimiter]
    # Registered (pattern, handler) pairs, in registration order.
    routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]]

    # Class used to wrap the environ; subclasses may substitute their own.
    request_class: typing.Type[Request] = Request

    def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
        self.rate_limiter = rate_limiter
        self.routes = []

    def __call__(
        self, environ: EnvironDict, send_status: WriteStatusCallable
    ) -> ApplicationResponse:
        """
        Handle a single request.

        This is a generator: the header is emitted through ``send_status``
        and the chunks of the response body (if any) are yielded.
        """
        try:
            request = self.request_class(environ)
        except Exception:
            # Any failure to build the request is reported as a bad request.
            send_status(Status.BAD_REQUEST, "Invalid URL")
            return

        if self.rate_limiter:
            response = self.rate_limiter.check(request)
            if response:
                send_status(response.status, response.meta)
                return

        # Routes are searched newest-first, so a route registered later takes
        # precedence over an earlier one that matches the same request.
        for route_pattern, callback in reversed(self.routes):
            match = route_pattern.match(request)
            if match:
                # Named groups of the path pattern become handler kwargs.
                callback_kwargs = match.groupdict()
                break
        else:
            callback = self.default_callback
            callback_kwargs = {}

        response = callback(request, **callback_kwargs)
        send_status(response.status, response.meta)

        if isinstance(response.body, (bytes, str, Deferred)):
            # A single chunk is yielded as-is rather than iterated over.
            yield response.body
        elif response.body:
            yield from response.body

    def route(
        self,
        path: str = ".*",
        scheme: str = "gemini",
        hostname: typing.Optional[str] = None,
        strict_hostname: bool = True,
        strict_trailing_slash: bool = False,
        strict_port: bool = True,
    ) -> typing.Callable[[RouteHandler], RouteHandler]:
        """
        Decorator for binding a function to a route based on the URL path.

            app = JetforceApplication()

            @app.route('/my-path')
            def my_path(request):
                return Response(Status.SUCCESS, 'text/plain', 'Hello world!')

        The arguments are forwarded to ``RoutePattern``; the decorated
        function is registered and returned unchanged.
        """
        # Pass every option by keyword: the field order of RoutePattern
        # differs from this signature, so positional arguments would assign
        # ``strict_trailing_slash`` to ``strict_port``.
        route_pattern = RoutePattern(
            path=path,
            scheme=scheme,
            hostname=hostname,
            strict_hostname=strict_hostname,
            strict_port=strict_port,
            strict_trailing_slash=strict_trailing_slash,
        )

        def wrap(func: RouteHandler) -> RouteHandler:
            self.routes.append((route_pattern, func))
            return func

        return wrap

    def default_callback(self, request: Request, **_: typing.Any) -> Response:
        """
        Fallback handler used when no registered route matches the request.

        Always returns a permanent failure response with the meta "Not Found".
        """
        return Response(Status.PERMANENT_FAILURE, "Not Found")
|