diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..9caba16 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[isort] +profile=black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ec7af1a..be145a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,6 @@ repos: hooks: - id: black - repo: https://github.com/pre-commit/mirrors-isort - rev: v4.3.21 + rev: v5.1.4 hooks: - - id: isort \ No newline at end of file + - id: isort diff --git a/examples/chatroom.py b/examples/chatroom.py index f69f455..bd81e26 100644 --- a/examples/chatroom.py +++ b/examples/chatroom.py @@ -17,9 +17,10 @@ streaming. from collections import deque from datetime import datetime -from jetforce import GeminiServer, JetforceApplication, Response, Status from twisted.internet.defer import AlreadyCalledError, Deferred +from jetforce import GeminiServer, JetforceApplication, Response, Status + class MessageQueue: def __init__(self, filename): diff --git a/examples/counter.py b/examples/counter.py index f60b89d..16ca794 100644 --- a/examples/counter.py +++ b/examples/counter.py @@ -9,11 +9,12 @@ loading the entire response into memory at once. """ import time -from jetforce import GeminiServer, JetforceApplication, Response, Status from twisted.internet import reactor from twisted.internet.task import deferLater from twisted.internet.threads import deferToThread +from jetforce import GeminiServer, JetforceApplication, Response, Status + def blocking_counter(): """ diff --git a/examples/rate_limit.py b/examples/rate_limit.py new file mode 100644 index 0000000..ce44ab6 --- /dev/null +++ b/examples/rate_limit.py @@ -0,0 +1,38 @@ +#!/usr/local/env python3 +""" +This example shows how you can implement advanced rate limiting schemes. +""" +from jetforce import GeminiServer, JetforceApplication, RateLimiter, Response, Status + +app = JetforceApplication() + +INDEX_PAGE = """\ +# Rate Limiting Demo + +=>/short short rate limiter (5/30s) +=>/long long rate limiter (60/5m) +""" + + +@app.route("", strict_trailing_slash=False) +def index(request): + return Response(Status.SUCCESS, "text/gemini", INDEX_PAGE) + + +@app.route("/short") +@RateLimiter("5/30s") +def short(request): + # Maximum of 5 requests per 30 seconds + return Response(Status.SUCCESS, "text/gemini", "Request was successful") + + +@app.route("/long") +@RateLimiter("60/5m") +def long(request): + # Maximum of 60 requests per 5 minutes + return Response(Status.SUCCESS, "text/gemini", "Request was successful") + + +if __name__ == "__main__": + server = GeminiServer(app, host="127.0.0.1", hostname="localhost") + server.run() diff --git a/jetforce/__init__.py b/jetforce/__init__.py index 072636d..16d0a3c 100644 --- a/jetforce/__init__.py +++ b/jetforce/__init__.py @@ -2,7 +2,14 @@ isort:skip_file """ from .__version__ import __version__ -from .app.base import JetforceApplication, Request, Response, RoutePattern, Status +from .app.base import ( + JetforceApplication, + Request, + Response, + RoutePattern, + Status, + RateLimiter, +) from .app.static import StaticDirectoryApplication from .app.composite import CompositeApplication from .protocol import GeminiProtocol diff --git a/jetforce/__main__.py b/jetforce/__main__.py index 18f74b1..9d1be3e 100644 --- a/jetforce/__main__.py +++ b/jetforce/__main__.py @@ -91,13 +91,18 @@ group.add_argument( metavar="FILE", dest="index_file", ) - group.add_argument( "--default-lang", help="A lang parameter that will be indicated in the response meta", default=None, dest="default_lang", ) +group.add_argument( + "--rate-limit", + help="An IP rate limit string, e.g. '60/5m' (60 requests per 5 minutes)", + default=None, + dest="rate_limit", +) def main(): @@ -107,6 +112,7 @@ def main(): index_file=args.index_file, cgi_directory=args.cgi_directory, default_lang=args.default_lang, + rate_limit=args.rate_limit, ) server = GeminiServer( app=app, diff --git a/jetforce/app/base.py b/jetforce/app/base.py index 7f83686..2913fdd 100644 --- a/jetforce/app/base.py +++ b/jetforce/app/base.py @@ -1,6 +1,8 @@ import dataclasses import re +import time import typing +from collections import defaultdict from urllib.parse import unquote, urlparse from twisted.internet.defer import Deferred @@ -196,3 +198,50 @@ class JetforceApplication: Set the error response based on the URL type. """ return Response(Status.PERMANENT_FAILURE, "Not Found") + + +class RateLimiter: + + RE = re.compile("(?P[0-9]+)/(?P[0-9]+)?(?P[smhd])") + + def __init__(self, rate: str) -> None: + match = self.RE.fullmatch(rate) + if not match: + raise ValueError(f"Invalid rate format: {rate}") + + rate_data = match.groupdict() + + self.number = int(rate_data["number"]) + self.period = int(rate_data["period"] or 1) + if rate_data["unit"] == "m": + self.period *= 60 + elif rate_data["unit"] == "h": + self.period += 60 * 60 + elif rate_data["unit"] == "d": + self.period *= 60 * 60 * 24 + + self.reset() + + def reset(self) -> None: + self.timestamp = time.time() + self.period + self.counter = defaultdict(int) + + def get_key(self, request: Request) -> typing.Optional[str]: + return request.environ["REMOTE_ADDR"] + + def __call__(self, func: typing.Callable) -> typing.Callable: + def handler(request, **kwargs) -> Response: + time_left = self.timestamp - time.time() + if time_left < 0: + self.reset() + + rate_key = self.get_key(request) + if rate_key is not None: + self.counter[rate_key] += 1 + if self.counter[rate_key] > self.number: + msg = f"Rate limit exceeded, wait {time_left:.0f} seconds." + return Response(Status.SLOW_DOWN, msg) + + return func(request, **kwargs) + + return handler diff --git a/jetforce/app/static.py b/jetforce/app/static.py index 5043e5d..104ae9e 100644 --- a/jetforce/app/static.py +++ b/jetforce/app/static.py @@ -6,7 +6,14 @@ import subprocess import typing import urllib.parse -from .base import JetforceApplication, Request, Response, RoutePattern, Status +from .base import ( + JetforceApplication, + RateLimiter, + Request, + Response, + RoutePattern, + Status, +) class StaticDirectoryApplication(JetforceApplication): @@ -32,9 +39,15 @@ class StaticDirectoryApplication(JetforceApplication): index_file: str = "index.gmi", cgi_directory: str = "cgi-bin", default_lang: typing.Optional[str] = None, + rate_limit: typing.Optional[str] = None, ): super().__init__() - self.routes.append((RoutePattern(), self.serve_static_file)) + + request_method = self.serve_static_file + if rate_limit is not None: + request_method = RateLimiter(rate_limit)(request_method) + + self.routes.append((RoutePattern(), request_method)) self.root = pathlib.Path(root_directory).resolve(strict=True) self.cgi_directory = cgi_directory.strip("/") + "/"