Rate limiting proof of concept

This commit is contained in:
Michael Lazar 2020-07-25 22:57:51 -04:00
parent afa210db4f
commit d4412956ad
9 changed files with 125 additions and 8 deletions

2
.isort.cfg Normal file
View File

@ -0,0 +1,2 @@
[isort]
profile=black

View File

@ -7,6 +7,6 @@ repos:
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
rev: v5.1.4
hooks:
- id: isort
- id: isort

View File

@ -17,9 +17,10 @@ streaming.
from collections import deque
from datetime import datetime
from jetforce import GeminiServer, JetforceApplication, Response, Status
from twisted.internet.defer import AlreadyCalledError, Deferred
from jetforce import GeminiServer, JetforceApplication, Response, Status
class MessageQueue:
def __init__(self, filename):

View File

@ -9,11 +9,12 @@ loading the entire response into memory at once.
"""
import time
from jetforce import GeminiServer, JetforceApplication, Response, Status
from twisted.internet import reactor
from twisted.internet.task import deferLater
from twisted.internet.threads import deferToThread
from jetforce import GeminiServer, JetforceApplication, Response, Status
def blocking_counter():
"""

38
examples/rate_limit.py Normal file
View File

@ -0,0 +1,38 @@
#!/usr/local/env python3
"""
This example shows how you can implement advanced rate limiting schemes.
"""
from jetforce import GeminiServer, JetforceApplication, RateLimiter, Response, Status
app = JetforceApplication()
INDEX_PAGE = """\
# Rate Limiting Demo
=>/short short rate limiter (5/30s)
=>/long long rate limiter (60/5m)
"""
@app.route("", strict_trailing_slash=False)
def index(request):
return Response(Status.SUCCESS, "text/gemini", INDEX_PAGE)
@app.route("/short")
@RateLimiter("5/30s")
def short(request):
# Maximum of 5 requests per 30 seconds
return Response(Status.SUCCESS, "text/gemini", "Request was successful")
@app.route("/long")
@RateLimiter("60/5m")
def long(request):
# Maximum of 60 requests per 5 minutes
return Response(Status.SUCCESS, "text/gemini", "Request was successful")
if __name__ == "__main__":
server = GeminiServer(app, host="127.0.0.1", hostname="localhost")
server.run()

View File

@ -2,7 +2,14 @@
isort:skip_file
"""
from .__version__ import __version__
from .app.base import JetforceApplication, Request, Response, RoutePattern, Status
from .app.base import (
JetforceApplication,
Request,
Response,
RoutePattern,
Status,
RateLimiter,
)
from .app.static import StaticDirectoryApplication
from .app.composite import CompositeApplication
from .protocol import GeminiProtocol

View File

@ -91,13 +91,18 @@ group.add_argument(
metavar="FILE",
dest="index_file",
)
group.add_argument(
"--default-lang",
help="A lang parameter that will be indicated in the response meta",
default=None,
dest="default_lang",
)
group.add_argument(
"--rate-limit",
help="An IP rate limit string, e.g. '60/5m' (60 requests per 5 minutes)",
default=None,
dest="rate_limit",
)
def main():
@ -107,6 +112,7 @@ def main():
index_file=args.index_file,
cgi_directory=args.cgi_directory,
default_lang=args.default_lang,
rate_limit=args.rate_limit,
)
server = GeminiServer(
app=app,

View File

@ -1,6 +1,8 @@
import dataclasses
import re
import time
import typing
from collections import defaultdict
from urllib.parse import unquote, urlparse
from twisted.internet.defer import Deferred
@ -196,3 +198,50 @@ class JetforceApplication:
Set the error response based on the URL type.
"""
return Response(Status.PERMANENT_FAILURE, "Not Found")
class RateLimiter:
RE = re.compile("(?P<number>[0-9]+)/(?P<period>[0-9]+)?(?P<unit>[smhd])")
def __init__(self, rate: str) -> None:
match = self.RE.fullmatch(rate)
if not match:
raise ValueError(f"Invalid rate format: {rate}")
rate_data = match.groupdict()
self.number = int(rate_data["number"])
self.period = int(rate_data["period"] or 1)
if rate_data["unit"] == "m":
self.period *= 60
elif rate_data["unit"] == "h":
self.period += 60 * 60
elif rate_data["unit"] == "d":
self.period *= 60 * 60 * 24
self.reset()
def reset(self) -> None:
self.timestamp = time.time() + self.period
self.counter = defaultdict(int)
def get_key(self, request: Request) -> typing.Optional[str]:
return request.environ["REMOTE_ADDR"]
def __call__(self, func: typing.Callable) -> typing.Callable:
def handler(request, **kwargs) -> Response:
time_left = self.timestamp - time.time()
if time_left < 0:
self.reset()
rate_key = self.get_key(request)
if rate_key is not None:
self.counter[rate_key] += 1
if self.counter[rate_key] > self.number:
msg = f"Rate limit exceeded, wait {time_left:.0f} seconds."
return Response(Status.SLOW_DOWN, msg)
return func(request, **kwargs)
return handler

View File

@ -6,7 +6,14 @@ import subprocess
import typing
import urllib.parse
from .base import JetforceApplication, Request, Response, RoutePattern, Status
from .base import (
JetforceApplication,
RateLimiter,
Request,
Response,
RoutePattern,
Status,
)
class StaticDirectoryApplication(JetforceApplication):
@ -32,9 +39,15 @@ class StaticDirectoryApplication(JetforceApplication):
index_file: str = "index.gmi",
cgi_directory: str = "cgi-bin",
default_lang: typing.Optional[str] = None,
rate_limit: typing.Optional[str] = None,
):
super().__init__()
self.routes.append((RoutePattern(), self.serve_static_file))
request_method = self.serve_static_file
if rate_limit is not None:
request_method = RateLimiter(rate_limit)(request_method)
self.routes.append((RoutePattern(), request_method))
self.root = pathlib.Path(root_directory).resolve(strict=True)
self.cgi_directory = cgi_directory.strip("/") + "/"