diff --git a/CHANGELOG.md b/CHANGELOG.md index ab598d8..22bc02c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,16 +1,29 @@ # Jetforce Changelog -### Unreleased +### v0.6.0 (Unreleased) + +#### Bugfixes -- File chunking has been optimized for streaming large static files. -- Server access logs are now redirected to ``stdout`` instead of ``stderr``. - This is intended to make it easier to use a log manager tool to split them - out from other server messages like startup information and error tracebacks. - The default mimetype for unknown file extensions will now be sent as "application/octet-stream" instead of "text/plain". The expectation is that it would be safer for a client to download an unknown file rather than attempting to display it inline as text. +#### Features + +- The static file server now has a ``--rate-limit`` flag that can be used + to define per-IP address rate limiting for requests. Requests that exceed + the specified rate will receive a 44 SLOW DOWN error response. +- Server access logs are now redirected to ``stdout`` instead of ``stderr``. + This is intended to make it easier to use a log manager tool to split them + out from other server messages like startup information and error tracebacks. +- File chunking has been optimized for streaming large static files. + +#### Examples + +- Added a new example that demonstrates how to use the new ``RateLimiter`` + class (examples/rate_limit.py). + ### v0.5.0 (2020-07-14) #### Spec Changes diff --git a/README.md b/README.md index 8704cbe..51a83ba 100644 --- a/README.md +++ b/README.md @@ -59,36 +59,39 @@ $ /opt/jetforce/venv/bin/jetforce Use the ``--help`` flag to view command-line options: ```bash -$ jetforce --help usage: jetforce [-h] [-V] [--host HOST] [--port PORT] [--hostname HOSTNAME] [--tls-certfile FILE] [--tls-keyfile FILE] [--tls-cafile FILE] - [--tls-capath DIR] [--dir DIR] [--cgi-dir DIR] - [--index-file FILE] + [--tls-capath DIR] [--dir DIR] [--cgi-dir DIR] [--index-file FILE] + [--default-lang DEFAULT_LANG] [--rate-limit RATE_LIMIT] An Experimental Gemini Protocol Server optional arguments: - -h, --help show this help message and exit - -V, --version show program's version number and exit + -h, --help show this help message and exit + -V, --version show program's version number and exit server configuration: - --host HOST Server address to bind to (default: 127.0.0.1) - --port PORT Server port to bind to (default: 1965) - --hostname HOSTNAME Server hostname (default: localhost) - --tls-certfile FILE Server TLS certificate file (default: None) - --tls-keyfile FILE Server TLS private key file (default: None) - --tls-cafile FILE A CA file to use for validating clients (default: None) - --tls-capath DIR A directory containing CA files for validating clients - (default: None) + --host HOST Server address to bind to (default: 127.0.0.1) + --port PORT Server port to bind to (default: 1965) + --hostname HOSTNAME Server hostname (default: localhost) + --tls-certfile FILE Server TLS certificate file (default: None) + --tls-keyfile FILE Server TLS private key file (default: None) + --tls-cafile FILE A CA file to use for validating clients (default: None) + --tls-capath DIR A directory containing CA files for validating clients (default: + None) fileserver configuration: - --dir DIR Root directory on the filesystem to serve (default: - /var/gemini) - --cgi-dir DIR CGI script directory, relative to the server's root - directory (default: cgi-bin) - --index-file FILE If a directory contains a file with this name, that - file will be served instead of auto-generating an index - page (default: index.gmi) + --dir DIR Root directory on the filesystem to serve (default: /var/gemini) + --cgi-dir DIR CGI script directory, relative to the server's root directory + (default: cgi-bin) + --index-file FILE If a directory contains a file with this name, that file will be + served instead of auto-generating an index page (default: index.gmi) + --default-lang DEFAULT_LANG + A lang parameter that will be indicated in the response meta + (default: None) + --rate-limit RATE_LIMIT + Enable IP rate limiting, e.g. '60/5m' (60 requests per 5 minutes) + (default: None) ``` ### Setting the ``hostname`` diff --git a/examples/rate_limit.py b/examples/rate_limit.py index ce44ab6..09e1146 100644 --- a/examples/rate_limit.py +++ b/examples/rate_limit.py @@ -1,10 +1,17 @@ #!/usr/local/env python3 """ -This example shows how you can implement advanced rate limiting schemes. +This example shows how you can implement rate limiting on a per-endpoint basis. """ from jetforce import GeminiServer, JetforceApplication, RateLimiter, Response, Status -app = JetforceApplication() +# Apply a global rate limiter that will be applied to all requests +global_rate_limiter = RateLimiter("100/m") +app = JetforceApplication(rate_limiter=global_rate_limiter) + +# Setup some custom rate limiting for specific endpoints +short_rate_limiter = RateLimiter("5/30s") +long_rate_limiter = RateLimiter("60/5m") + INDEX_PAGE = """\ # Rate Limiting Demo @@ -20,16 +27,14 @@ def index(request): @app.route("/short") -@RateLimiter("5/30s") +@short_rate_limiter.apply def short(request): - # Maximum of 5 requests per 30 seconds return Response(Status.SUCCESS, "text/gemini", "Request was successful") @app.route("/long") -@RateLimiter("60/5m") +@long_rate_limiter.apply def long(request): - # Maximum of 60 requests per 5 minutes return Response(Status.SUCCESS, "text/gemini", "Request was successful") diff --git a/jetforce/__main__.py b/jetforce/__main__.py index 9d1be3e..7fea58c 100644 --- a/jetforce/__main__.py +++ b/jetforce/__main__.py @@ -9,6 +9,7 @@ import argparse import sys from .__version__ import __version__ +from .app.base import RateLimiter from .app.static import StaticDirectoryApplication from .server import GeminiServer @@ -99,7 +100,7 @@ group.add_argument( ) group.add_argument( "--rate-limit", - help="An IP rate limit string, e.g. '60/5m' (60 requests per 5 minutes)", + help="Enable IP rate limiting, e.g. '60/5m' (60 requests per 5 minutes)", default=None, dest="rate_limit", ) @@ -107,12 +108,13 @@ group.add_argument( def main(): args = parser.parse_args() + rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None app = StaticDirectoryApplication( root_directory=args.root_directory, index_file=args.index_file, cgi_directory=args.cgi_directory, default_lang=args.default_lang, - rate_limit=args.rate_limit, + rate_limiter=rate_limiter, ) server = GeminiServer( app=app, diff --git a/jetforce/app/base.py b/jetforce/app/base.py index 2913fdd..fbc5571 100644 --- a/jetforce/app/base.py +++ b/jetforce/app/base.py @@ -123,6 +123,91 @@ class RoutePattern: return re.fullmatch(self.path, request_path) +RouteHandler = typing.Callable[..., Response] + + +class RateLimiter: + """ + A class that can be used to apply rate-limiting to endpoints. + + Rates are defined as human-readable strings, e.g. + + "5/s (5 requests per-second) + "10/5m" (10 requests per-5 minutes) + "100/2h" (100 requests per-2 hours) + "1000/d" (1k requests per-day) + """ + + RE = re.compile("(?P[0-9]+)/(?P[0-9]+)?(?P[smhd])") + + def __init__(self, rate: str) -> None: + match = self.RE.fullmatch(rate) + if not match: + raise ValueError(f"Invalid rate format: {rate}") + + rate_data = match.groupdict() + + self.number = int(rate_data["number"]) + self.period = int(rate_data["period"] or 1) + if rate_data["unit"] == "m": + self.period *= 60 + elif rate_data["unit"] == "h": + self.period += 60 * 60 + elif rate_data["unit"] == "d": + self.period *= 60 * 60 * 24 + + self.reset() + + def reset(self) -> None: + self.next_timestamp = time.time() + self.period + self.rate_counter = defaultdict(int) + + def get_key(self, request: Request) -> typing.Optional[str]: + """ + Rate limit based on the client's IP-address. + """ + return request.environ["REMOTE_ADDR"] + + def check(self, request: Request) -> typing.Optional[Response]: + """ + Check if the given request should be rate limited. + + This method will return a failure response if the request should be + rate limited. + """ + time_left = self.next_timestamp - time.time() + if time_left < 0: + self.reset() + + key = self.get_key(request) + if key is not None: + self.rate_counter[key] += 1 + if self.rate_counter[key] > self.number: + msg = f"Rate limit exceeded, wait {time_left:.0f} seconds." + return Response(Status.SLOW_DOWN, msg) + + def apply(self, wrapped_func: RouteHandler) -> RouteHandler: + """ + Decorator to apply rate limiting to an individual application route. + + Usage: + rate_limiter = RateLimiter("10/m") + + @app.route("/endpoint") + @rate_limiter.apply + def my_endpoint(request): + return Response(Status.SUCCESS, "text/gemini", "hello world!") + """ + + def wrapper(request: Request, **kwargs) -> Response: + response = self.check(request) + if response: + return response + return wrapped_func(request, **kwargs) + + return wrapper + + class JetforceApplication: """ Base Jetforce application class with primitive URL routing. @@ -135,10 +220,9 @@ class JetforceApplication: how to accomplish this. """ - def __init__(self): - self.routes: typing.List[ - typing.Tuple[RoutePattern, typing.Callable[[Request, ...], Response]] - ] = [] + def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None): + self.rate_limiter = rate_limiter + self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = [] def __call__( self, environ: dict, send_status: typing.Callable @@ -149,6 +233,12 @@ class JetforceApplication: send_status(Status.BAD_REQUEST, "Invalid URL") return + if self.rate_limiter: + response = self.rate_limiter.check(request) + if response: + send_status(response.status, response.meta) + return + for route_pattern, callback in self.routes[::-1]: match = route_pattern.match(request) if route_pattern.match(request): @@ -187,7 +277,7 @@ class JetforceApplication: path, scheme, hostname, strict_hostname, strict_trailing_slash ) - def wrap(func: typing.Callable) -> typing.Callable: + def wrap(func: RouteHandler) -> RouteHandler: self.routes.append((route_pattern, func)) return func @@ -198,50 +288,3 @@ class JetforceApplication: Set the error response based on the URL type. """ return Response(Status.PERMANENT_FAILURE, "Not Found") - - -class RateLimiter: - - RE = re.compile("(?P[0-9]+)/(?P[0-9]+)?(?P[smhd])") - - def __init__(self, rate: str) -> None: - match = self.RE.fullmatch(rate) - if not match: - raise ValueError(f"Invalid rate format: {rate}") - - rate_data = match.groupdict() - - self.number = int(rate_data["number"]) - self.period = int(rate_data["period"] or 1) - if rate_data["unit"] == "m": - self.period *= 60 - elif rate_data["unit"] == "h": - self.period += 60 * 60 - elif rate_data["unit"] == "d": - self.period *= 60 * 60 * 24 - - self.reset() - - def reset(self) -> None: - self.timestamp = time.time() + self.period - self.counter = defaultdict(int) - - def get_key(self, request: Request) -> typing.Optional[str]: - return request.environ["REMOTE_ADDR"] - - def __call__(self, func: typing.Callable) -> typing.Callable: - def handler(request, **kwargs) -> Response: - time_left = self.timestamp - time.time() - if time_left < 0: - self.reset() - - rate_key = self.get_key(request) - if rate_key is not None: - self.counter[rate_key] += 1 - if self.counter[rate_key] > self.number: - msg = f"Rate limit exceeded, wait {time_left:.0f} seconds." - return Response(Status.SLOW_DOWN, msg) - - return func(request, **kwargs) - - return handler diff --git a/jetforce/app/static.py b/jetforce/app/static.py index 104ae9e..a0f8156 100644 --- a/jetforce/app/static.py +++ b/jetforce/app/static.py @@ -39,15 +39,11 @@ class StaticDirectoryApplication(JetforceApplication): index_file: str = "index.gmi", cgi_directory: str = "cgi-bin", default_lang: typing.Optional[str] = None, - rate_limit: typing.Optional[str] = None, + rate_limiter: typing.Optional[RateLimiter] = None, ): - super().__init__() + super().__init__(rate_limiter=rate_limiter) - request_method = self.serve_static_file - if rate_limit is not None: - request_method = RateLimiter(rate_limit)(request_method) - - self.routes.append((RoutePattern(), request_method)) + self.routes.append((RoutePattern(), self.serve_static_file)) self.root = pathlib.Path(root_directory).resolve(strict=True) self.cgi_directory = cgi_directory.strip("/") + "/"