"""Rate limit decorator for FastAPI endpoints.""" from __future__ import annotations import functools from collections.abc import Callable from typing import TYPE_CHECKING, Any, TypeVar, overload from fastapi_traffic.core.algorithms import Algorithm from fastapi_traffic.core.config import ( KeyExtractor, RateLimitConfig, default_key_extractor, ) from fastapi_traffic.core.limiter import get_limiter if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response from fastapi_traffic.exceptions import RateLimitExceeded F = TypeVar("F", bound=Callable[..., Any]) # Note: Config loader from secrets .env @overload def rate_limit( limit: int, *, window_size: float = ..., algorithm: Algorithm = ..., key_prefix: str = ..., key_extractor: KeyExtractor = ..., burst_size: int | None = ..., include_headers: bool = ..., error_message: str = ..., status_code: int = ..., skip_on_error: bool = ..., cost: int = ..., exempt_when: Callable[[Request], bool] | None = ..., on_blocked: Callable[[Request, Any], Any] | None = ..., ) -> Callable[[F], F]: ... @overload def rate_limit( limit: int, window_size: float, /, ) -> Callable[[F], F]: ... def rate_limit( limit: int, window_size: float = 60.0, *, algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER, key_prefix: str = "ratelimit", key_extractor: KeyExtractor = default_key_extractor, burst_size: int | None = None, include_headers: bool = True, error_message: str = "Rate limit exceeded", status_code: int = 429, skip_on_error: bool = False, cost: int = 1, exempt_when: Callable[[Request], bool] | None = None, on_blocked: Callable[[Request, Any], Any] | None = None, ) -> Callable[[F], F]: """Decorator to apply rate limiting to a FastAPI endpoint. Args: limit: Maximum number of requests allowed in the window. window_size: Time window in seconds. algorithm: Rate limiting algorithm to use. key_prefix: Prefix for the rate limit key. key_extractor: Function to extract the client identifier from request. burst_size: Maximum burst size (for token bucket/leaky bucket). include_headers: Whether to include rate limit headers in response. error_message: Error message when rate limit is exceeded. status_code: HTTP status code when rate limit is exceeded. skip_on_error: Skip rate limiting if backend errors occur. cost: Cost of each request (default 1). exempt_when: Function to determine if request should be exempt. on_blocked: Callback when a request is blocked. Returns: Decorated function with rate limiting applied. Example: ```python from fastapi import FastAPI from fastapi_traffic import rate_limit app = FastAPI() @app.get("/api/resource") @rate_limit(100, 60) # 100 requests per minute async def get_resource(): return {"message": "Hello"} ``` """ config = RateLimitConfig( limit=limit, window_size=window_size, algorithm=algorithm, key_prefix=key_prefix, key_extractor=key_extractor, burst_size=burst_size, include_headers=include_headers, error_message=error_message, status_code=status_code, skip_on_error=skip_on_error, cost=cost, exempt_when=exempt_when, on_blocked=on_blocked, ) def decorator(func: F) -> F: @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: request = _extract_request(args, kwargs) if request is None: return await func(*args, **kwargs) limiter = get_limiter() result = await limiter.hit(request, config) response = await func(*args, **kwargs) if config.include_headers and hasattr(response, "headers"): for key, value in result.info.to_headers().items(): response.headers[key] = value return response @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: import asyncio return asyncio.get_event_loop().run_until_complete( async_wrapper(*args, **kwargs) ) if _is_coroutine_function(func): return async_wrapper # type: ignore[return-value] return sync_wrapper # type: ignore[return-value] return decorator def _extract_request( args: tuple[Any, ...], kwargs: dict[str, Any], ) -> Request | None: """Extract the Request object from function arguments.""" from starlette.requests import Request for arg in args: if isinstance(arg, Request): return arg for value in kwargs.values(): if isinstance(value, Request): return value if "request" in kwargs: req = kwargs["request"] if isinstance(req, Request): return req return None def _is_coroutine_function(func: Callable[..., Any]) -> bool: """Check if a function is a coroutine function.""" import asyncio import inspect return asyncio.iscoroutinefunction(func) or inspect.iscoroutinefunction(func) class RateLimitDependency: """FastAPI dependency for rate limiting. Example: ```python from fastapi import FastAPI, Depends from fastapi_traffic import RateLimitDependency app = FastAPI() rate_limiter = RateLimitDependency(limit=100, window_size=60) @app.get("/api/resource") async def get_resource(rate_limit_info = Depends(rate_limiter)): return {"remaining": rate_limit_info.remaining} ``` """ __slots__ = ("_config",) def __init__( self, limit: int, window_size: float = 60.0, *, algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER, key_prefix: str = "ratelimit", key_extractor: KeyExtractor = default_key_extractor, burst_size: int | None = None, error_message: str = "Rate limit exceeded", status_code: int = 429, skip_on_error: bool = False, cost: int = 1, exempt_when: Callable[[Request], bool] | None = None, ) -> None: self._config = RateLimitConfig( limit=limit, window_size=window_size, algorithm=algorithm, key_prefix=key_prefix, key_extractor=key_extractor, burst_size=burst_size, include_headers=True, error_message=error_message, status_code=status_code, skip_on_error=skip_on_error, cost=cost, exempt_when=exempt_when, ) async def __call__(self, request: Request) -> Any: """Check rate limit and return info.""" limiter = get_limiter() result = await limiter.hit(request, self._config) return result.info def create_rate_limit_response( exc: RateLimitExceeded, *, include_headers: bool = True, ) -> Response: """Create a rate limit exceeded response. Args: exc: The RateLimitExceeded exception. include_headers: Whether to include rate limit headers. Returns: A JSONResponse with rate limit information. """ from starlette.responses import JSONResponse headers: dict[str, str] = {} if include_headers and exc.limit_info is not None: headers = exc.limit_info.to_headers() return JSONResponse( status_code=429, content={ "detail": exc.message, "retry_after": exc.retry_after, }, headers=headers, )