Restructure rate limiting

This commit is contained in:
Michael Lazar 2020-07-27 00:02:51 -04:00
parent d4412956ad
commit 1e5be2b45a
6 changed files with 154 additions and 92 deletions

View File

@ -1,16 +1,29 @@
# Jetforce Changelog
### Unreleased
### v0.6.0 (Unreleased)
#### Bugfixes
- File chunking has been optimized for streaming large static files.
- Server access logs are now redirected to ``stdout`` instead of ``stderr``.
This is intended to make it easier to use a log manager tool to split them
out from other server messages like startup information and error tracebacks.
- The default mimetype for unknown file extensions will now be sent as
"application/octet-stream" instead of "text/plain". The expectation is that
it would be safer for a client to download an unknown file rather than
attempting to display it inline as text.
#### Features
- The static file server now has a ``--rate-limit`` flag that can be used
to define per-IP address rate limiting for requests. Requests that exceed
the specified rate will receive a 44 SLOW DOWN error response.
- Server access logs are now redirected to ``stdout`` instead of ``stderr``.
This is intended to make it easier to use a log manager tool to split them
out from other server messages like startup information and error tracebacks.
- File chunking has been optimized for streaming large static files.
#### Examples
- Added a new example that demonstrates how to use the new ``RateLimiter``
class (examples/rate_limit.py).
### v0.5.0 (2020-07-14)
#### Spec Changes

View File

@ -59,11 +59,10 @@ $ /opt/jetforce/venv/bin/jetforce
Use the ``--help`` flag to view command-line options:
```bash
$ jetforce --help
usage: jetforce [-h] [-V] [--host HOST] [--port PORT] [--hostname HOSTNAME]
[--tls-certfile FILE] [--tls-keyfile FILE] [--tls-cafile FILE]
[--tls-capath DIR] [--dir DIR] [--cgi-dir DIR]
[--index-file FILE]
[--tls-capath DIR] [--dir DIR] [--cgi-dir DIR] [--index-file FILE]
[--default-lang DEFAULT_LANG] [--rate-limit RATE_LIMIT]
An Experimental Gemini Protocol Server
@ -78,17 +77,21 @@ server configuration:
--tls-certfile FILE Server TLS certificate file (default: None)
--tls-keyfile FILE Server TLS private key file (default: None)
--tls-cafile FILE A CA file to use for validating clients (default: None)
--tls-capath DIR A directory containing CA files for validating clients
(default: None)
--tls-capath DIR A directory containing CA files for validating clients (default:
None)
fileserver configuration:
--dir DIR Root directory on the filesystem to serve (default:
/var/gemini)
--cgi-dir DIR CGI script directory, relative to the server's root
directory (default: cgi-bin)
--index-file FILE If a directory contains a file with this name, that
file will be served instead of auto-generating an index
page (default: index.gmi)
--dir DIR Root directory on the filesystem to serve (default: /var/gemini)
--cgi-dir DIR CGI script directory, relative to the server's root directory
(default: cgi-bin)
--index-file FILE If a directory contains a file with this name, that file will be
served instead of auto-generating an index page (default: index.gmi)
--default-lang DEFAULT_LANG
A lang parameter that will be indicated in the response meta
(default: None)
--rate-limit RATE_LIMIT
Enable IP rate limiting, e.g. '60/5m' (60 requests per 5 minutes)
(default: None)
```
### Setting the ``hostname``

View File

@ -1,10 +1,17 @@
#!/usr/local/env python3
"""
This example shows how you can implement advanced rate limiting schemes.
This example shows how you can implement rate limiting on a per-endpoint basis.
"""
from jetforce import GeminiServer, JetforceApplication, RateLimiter, Response, Status
app = JetforceApplication()
# Apply a global rate limiter that will be applied to all requests
global_rate_limiter = RateLimiter("100/m")
app = JetforceApplication(rate_limiter=global_rate_limiter)
# Setup some custom rate limiting for specific endpoints
short_rate_limiter = RateLimiter("5/30s")
long_rate_limiter = RateLimiter("60/5m")
INDEX_PAGE = """\
# Rate Limiting Demo
@ -20,16 +27,14 @@ def index(request):
@app.route("/short")
@RateLimiter("5/30s")
@short_rate_limiter.apply
def short(request):
# Maximum of 5 requests per 30 seconds
return Response(Status.SUCCESS, "text/gemini", "Request was successful")
@app.route("/long")
@RateLimiter("60/5m")
@long_rate_limiter.apply
def long(request):
# Maximum of 60 requests per 5 minutes
return Response(Status.SUCCESS, "text/gemini", "Request was successful")

View File

@ -9,6 +9,7 @@ import argparse
import sys
from .__version__ import __version__
from .app.base import RateLimiter
from .app.static import StaticDirectoryApplication
from .server import GeminiServer
@ -99,7 +100,7 @@ group.add_argument(
)
group.add_argument(
"--rate-limit",
help="An IP rate limit string, e.g. '60/5m' (60 requests per 5 minutes)",
help="Enable IP rate limiting, e.g. '60/5m' (60 requests per 5 minutes)",
default=None,
dest="rate_limit",
)
@ -107,12 +108,13 @@ group.add_argument(
def main():
args = parser.parse_args()
rate_limiter = RateLimiter(args.rate_limit) if args.rate_limit else None
app = StaticDirectoryApplication(
root_directory=args.root_directory,
index_file=args.index_file,
cgi_directory=args.cgi_directory,
default_lang=args.default_lang,
rate_limit=args.rate_limit,
rate_limiter=rate_limiter,
)
server = GeminiServer(
app=app,

View File

@ -123,6 +123,91 @@ class RoutePattern:
return re.fullmatch(self.path, request_path)
RouteHandler = typing.Callable[..., Response]
class RateLimiter:
"""
A class that can be used to apply rate-limiting to endpoints.
Rates are defined as human-readable strings, e.g.
"5/s (5 requests per-second)
"10/5m" (10 requests per-5 minutes)
"100/2h" (100 requests per-2 hours)
"1000/d" (1k requests per-day)
"""
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate)
if not match:
raise ValueError(f"Invalid rate format: {rate}")
rate_data = match.groupdict()
self.number = int(rate_data["number"])
self.period = int(rate_data["period"] or 1)
if rate_data["unit"] == "m":
self.period *= 60
elif rate_data["unit"] == "h":
self.period += 60 * 60
elif rate_data["unit"] == "d":
self.period *= 60 * 60 * 24
self.reset()
def reset(self) -> None:
self.next_timestamp = time.time() + self.period
self.rate_counter = defaultdict(int)
def get_key(self, request: Request) -> typing.Optional[str]:
"""
Rate limit based on the client's IP-address.
"""
return request.environ["REMOTE_ADDR"]
def check(self, request: Request) -> typing.Optional[Response]:
"""
Check if the given request should be rate limited.
This method will return a failure response if the request should be
rate limited.
"""
time_left = self.next_timestamp - time.time()
if time_left < 0:
self.reset()
key = self.get_key(request)
if key is not None:
self.rate_counter[key] += 1
if self.rate_counter[key] > self.number:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg)
def apply(self, wrapped_func: RouteHandler) -> RouteHandler:
"""
Decorator to apply rate limiting to an individual application route.
Usage:
rate_limiter = RateLimiter("10/m")
@app.route("/endpoint")
@rate_limiter.apply
def my_endpoint(request):
return Response(Status.SUCCESS, "text/gemini", "hello world!")
"""
def wrapper(request: Request, **kwargs) -> Response:
response = self.check(request)
if response:
return response
return wrapped_func(request, **kwargs)
return wrapper
class JetforceApplication:
"""
Base Jetforce application class with primitive URL routing.
@ -135,10 +220,9 @@ class JetforceApplication:
how to accomplish this.
"""
def __init__(self):
self.routes: typing.List[
typing.Tuple[RoutePattern, typing.Callable[[Request, ...], Response]]
] = []
def __init__(self, rate_limiter: typing.Optional[RateLimiter] = None):
self.rate_limiter = rate_limiter
self.routes: typing.List[typing.Tuple[RoutePattern, RouteHandler]] = []
def __call__(
self, environ: dict, send_status: typing.Callable
@ -149,6 +233,12 @@ class JetforceApplication:
send_status(Status.BAD_REQUEST, "Invalid URL")
return
if self.rate_limiter:
response = self.rate_limiter.check(request)
if response:
send_status(response.status, response.meta)
return
for route_pattern, callback in self.routes[::-1]:
match = route_pattern.match(request)
if route_pattern.match(request):
@ -187,7 +277,7 @@ class JetforceApplication:
path, scheme, hostname, strict_hostname, strict_trailing_slash
)
def wrap(func: typing.Callable) -> typing.Callable:
def wrap(func: RouteHandler) -> RouteHandler:
self.routes.append((route_pattern, func))
return func
@ -198,50 +288,3 @@ class JetforceApplication:
Set the error response based on the URL type.
"""
return Response(Status.PERMANENT_FAILURE, "Not Found")
class RateLimiter:
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate)
if not match:
raise ValueError(f"Invalid rate format: {rate}")
rate_data = match.groupdict()
self.number = int(rate_data["number"])
self.period = int(rate_data["period"] or 1)
if rate_data["unit"] == "m":
self.period *= 60
elif rate_data["unit"] == "h":
self.period += 60 * 60
elif rate_data["unit"] == "d":
self.period *= 60 * 60 * 24
self.reset()
def reset(self) -> None:
self.timestamp = time.time() + self.period
self.counter = defaultdict(int)
def get_key(self, request: Request) -> typing.Optional[str]:
return request.environ["REMOTE_ADDR"]
def __call__(self, func: typing.Callable) -> typing.Callable:
def handler(request, **kwargs) -> Response:
time_left = self.timestamp - time.time()
if time_left < 0:
self.reset()
rate_key = self.get_key(request)
if rate_key is not None:
self.counter[rate_key] += 1
if self.counter[rate_key] > self.number:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg)
return func(request, **kwargs)
return handler

View File

@ -39,15 +39,11 @@ class StaticDirectoryApplication(JetforceApplication):
index_file: str = "index.gmi",
cgi_directory: str = "cgi-bin",
default_lang: typing.Optional[str] = None,
rate_limit: typing.Optional[str] = None,
rate_limiter: typing.Optional[RateLimiter] = None,
):
super().__init__()
super().__init__(rate_limiter=rate_limiter)
request_method = self.serve_static_file
if rate_limit is not None:
request_method = RateLimiter(rate_limit)(request_method)
self.routes.append((RoutePattern(), request_method))
self.routes.append((RoutePattern(), self.serve_static_file))
self.root = pathlib.Path(root_directory).resolve(strict=True)
self.cgi_directory = cgi_directory.strip("/") + "/"