Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
184
fastapi_traffic/middleware.py
Normal file
184
fastapi_traffic/middleware.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Rate limiting middleware for Starlette/FastAPI applications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig, default_key_extractor
|
||||
from fastapi_traffic.core.limiter import RateLimiter
|
||||
from fastapi_traffic.exceptions import RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for global rate limiting across all endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
limit: int = 100,
|
||||
window_size: float = 60.0,
|
||||
algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
|
||||
backend: Backend | None = None,
|
||||
key_prefix: str = "middleware",
|
||||
include_headers: bool = True,
|
||||
error_message: str = "Rate limit exceeded. Please try again later.",
|
||||
status_code: int = 429,
|
||||
skip_on_error: bool = False,
|
||||
exempt_paths: set[str] | None = None,
|
||||
exempt_ips: set[str] | None = None,
|
||||
key_extractor: Callable[[Request], str] = default_key_extractor,
|
||||
) -> None:
|
||||
"""Initialize the rate limit middleware.
|
||||
|
||||
Args:
|
||||
app: The ASGI application.
|
||||
limit: Maximum requests per window.
|
||||
window_size: Time window in seconds.
|
||||
algorithm: Rate limiting algorithm.
|
||||
backend: Storage backend (defaults to MemoryBackend).
|
||||
key_prefix: Prefix for rate limit keys.
|
||||
include_headers: Include rate limit headers in response.
|
||||
error_message: Error message when rate limited.
|
||||
status_code: HTTP status code when rate limited.
|
||||
skip_on_error: Skip rate limiting on backend errors.
|
||||
exempt_paths: Paths to exempt from rate limiting.
|
||||
exempt_ips: IP addresses to exempt from rate limiting.
|
||||
key_extractor: Function to extract client identifier.
|
||||
"""
|
||||
super().__init__(app)
|
||||
|
||||
self._backend = backend or MemoryBackend()
|
||||
self._config = RateLimitConfig(
|
||||
limit=limit,
|
||||
window_size=window_size,
|
||||
algorithm=algorithm,
|
||||
key_prefix=key_prefix,
|
||||
key_extractor=key_extractor,
|
||||
include_headers=include_headers,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
skip_on_error=skip_on_error,
|
||||
)
|
||||
|
||||
global_config = GlobalConfig(
|
||||
backend=self._backend,
|
||||
exempt_paths=exempt_paths or set(),
|
||||
exempt_ips=exempt_ips or set(),
|
||||
)
|
||||
|
||||
self._limiter = RateLimiter(self._backend, config=global_config)
|
||||
self._include_headers = include_headers
|
||||
self._error_message = error_message
|
||||
self._status_code = status_code
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Process the request with rate limiting."""
|
||||
try:
|
||||
result = await self._limiter.check(request, self._config)
|
||||
|
||||
if not result.allowed:
|
||||
return self._create_rate_limit_response(result)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if self._include_headers:
|
||||
for key, value in result.info.to_headers().items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
except RateLimitExceeded as exc:
|
||||
return JSONResponse(
|
||||
status_code=self._status_code,
|
||||
content={
|
||||
"detail": exc.message,
|
||||
"retry_after": exc.retry_after,
|
||||
},
|
||||
headers=exc.limit_info.to_headers() if exc.limit_info else {},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in rate limit middleware: %s", e)
|
||||
if self._config.skip_on_error:
|
||||
return await call_next(request)
|
||||
raise
|
||||
|
||||
def _create_rate_limit_response(self, result: object) -> JSONResponse:
|
||||
"""Create a rate limit exceeded response."""
|
||||
from fastapi_traffic.core.models import RateLimitResult
|
||||
|
||||
if isinstance(result, RateLimitResult):
|
||||
headers = result.info.to_headers()
|
||||
retry_after = result.info.retry_after
|
||||
else:
|
||||
headers = {}
|
||||
retry_after = None
|
||||
|
||||
return JSONResponse(
|
||||
status_code=self._status_code,
|
||||
content={
|
||||
"detail": self._error_message,
|
||||
"retry_after": retry_after,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
class SlidingWindowMiddleware(RateLimitMiddleware):
|
||||
"""Convenience middleware using sliding window algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
limit: int = 100,
|
||||
window_size: float = 60.0,
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
limit=limit,
|
||||
window_size=window_size,
|
||||
algorithm=Algorithm.SLIDING_WINDOW,
|
||||
**kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class TokenBucketMiddleware(RateLimitMiddleware):
|
||||
"""Convenience middleware using token bucket algorithm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
limit: int = 100,
|
||||
window_size: float = 60.0,
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
app,
|
||||
limit=limit,
|
||||
window_size=window_size,
|
||||
algorithm=Algorithm.TOKEN_BUCKET,
|
||||
**kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
Reference in New Issue
Block a user