"""Rate limiting middleware for Starlette/FastAPI applications.""" from __future__ import annotations import logging from typing import TYPE_CHECKING, Awaitable, Callable from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from fastapi_traffic.backends.memory import MemoryBackend from fastapi_traffic.core.algorithms import Algorithm from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig, default_key_extractor from fastapi_traffic.core.limiter import RateLimiter from fastapi_traffic.exceptions import RateLimitExceeded if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response from starlette.types import ASGIApp from fastapi_traffic.backends.base import Backend logger = logging.getLogger(__name__) class RateLimitMiddleware(BaseHTTPMiddleware): """Middleware for global rate limiting across all endpoints.""" def __init__( self, app: ASGIApp, *, limit: int = 100, window_size: float = 60.0, algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER, backend: Backend | None = None, key_prefix: str = "middleware", include_headers: bool = True, error_message: str = "Rate limit exceeded. Please try again later.", status_code: int = 429, skip_on_error: bool = False, exempt_paths: set[str] | None = None, exempt_ips: set[str] | None = None, key_extractor: Callable[[Request], str] = default_key_extractor, ) -> None: """Initialize the rate limit middleware. Args: app: The ASGI application. limit: Maximum requests per window. window_size: Time window in seconds. algorithm: Rate limiting algorithm. backend: Storage backend (defaults to MemoryBackend). key_prefix: Prefix for rate limit keys. include_headers: Include rate limit headers in response. error_message: Error message when rate limited. status_code: HTTP status code when rate limited. skip_on_error: Skip rate limiting on backend errors. exempt_paths: Paths to exempt from rate limiting. exempt_ips: IP addresses to exempt from rate limiting. key_extractor: Function to extract client identifier. """ super().__init__(app) self._backend = backend or MemoryBackend() self._config = RateLimitConfig( limit=limit, window_size=window_size, algorithm=algorithm, key_prefix=key_prefix, key_extractor=key_extractor, include_headers=include_headers, error_message=error_message, status_code=status_code, skip_on_error=skip_on_error, ) global_config = GlobalConfig( backend=self._backend, exempt_paths=exempt_paths or set(), exempt_ips=exempt_ips or set(), ) self._limiter = RateLimiter(self._backend, config=global_config) self._include_headers = include_headers self._error_message = error_message self._status_code = status_code async def dispatch( self, request: Request, call_next: Callable[[Request], Awaitable[Response]], ) -> Response: """Process the request with rate limiting.""" try: result = await self._limiter.check(request, self._config) if not result.allowed: return self._create_rate_limit_response(result) response = await call_next(request) if self._include_headers: for key, value in result.info.to_headers().items(): response.headers[key] = value return response except RateLimitExceeded as exc: return JSONResponse( status_code=self._status_code, content={ "detail": exc.message, "retry_after": exc.retry_after, }, headers=exc.limit_info.to_headers() if exc.limit_info else {}, ) except Exception as e: logger.exception("Error in rate limit middleware: %s", e) if self._config.skip_on_error: return await call_next(request) raise def _create_rate_limit_response(self, result: object) -> JSONResponse: """Create a rate limit exceeded response.""" from fastapi_traffic.core.models import RateLimitResult if isinstance(result, RateLimitResult): headers = result.info.to_headers() retry_after = result.info.retry_after else: headers = {} retry_after = None return JSONResponse( status_code=self._status_code, content={ "detail": self._error_message, "retry_after": retry_after, }, headers=headers, ) class SlidingWindowMiddleware(RateLimitMiddleware): """Convenience middleware using sliding window algorithm.""" def __init__( self, app: ASGIApp, *, limit: int = 100, window_size: float = 60.0, **kwargs: object, ) -> None: super().__init__( app, limit=limit, window_size=window_size, algorithm=Algorithm.SLIDING_WINDOW, **kwargs, # type: ignore[arg-type] ) class TokenBucketMiddleware(RateLimitMiddleware): """Convenience middleware using token bucket algorithm.""" def __init__( self, app: ASGIApp, *, limit: int = 100, window_size: float = 60.0, **kwargs: object, ) -> None: super().__init__( app, limit=limit, window_size=window_size, algorithm=Algorithm.TOKEN_BUCKET, **kwargs, # type: ignore[arg-type] )