Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
259
fastapi_traffic/core/decorator.py
Normal file
259
fastapi_traffic/core/decorator.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Rate limit decorator for FastAPI endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
|
||||
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import KeyExtractor, RateLimitConfig, default_key_extractor
|
||||
from fastapi_traffic.core.limiter import get_limiter
|
||||
from fastapi_traffic.exceptions import RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def rate_limit(
|
||||
limit: int,
|
||||
*,
|
||||
window_size: float = ...,
|
||||
algorithm: Algorithm = ...,
|
||||
key_prefix: str = ...,
|
||||
key_extractor: KeyExtractor = ...,
|
||||
burst_size: int | None = ...,
|
||||
include_headers: bool = ...,
|
||||
error_message: str = ...,
|
||||
status_code: int = ...,
|
||||
skip_on_error: bool = ...,
|
||||
cost: int = ...,
|
||||
exempt_when: Callable[[Request], bool] | None = ...,
|
||||
on_blocked: Callable[[Request, Any], Any] | None = ...,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rate_limit(
|
||||
limit: int,
|
||||
window_size: float,
|
||||
/,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
def rate_limit(
|
||||
limit: int,
|
||||
window_size: float = 60.0,
|
||||
*,
|
||||
algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
|
||||
key_prefix: str = "ratelimit",
|
||||
key_extractor: KeyExtractor = default_key_extractor,
|
||||
burst_size: int | None = None,
|
||||
include_headers: bool = True,
|
||||
error_message: str = "Rate limit exceeded",
|
||||
status_code: int = 429,
|
||||
skip_on_error: bool = False,
|
||||
cost: int = 1,
|
||||
exempt_when: Callable[[Request], bool] | None = None,
|
||||
on_blocked: Callable[[Request, Any], Any] | None = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to apply rate limiting to a FastAPI endpoint.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of requests allowed in the window.
|
||||
window_size: Time window in seconds.
|
||||
algorithm: Rate limiting algorithm to use.
|
||||
key_prefix: Prefix for the rate limit key.
|
||||
key_extractor: Function to extract the client identifier from request.
|
||||
burst_size: Maximum burst size (for token bucket/leaky bucket).
|
||||
include_headers: Whether to include rate limit headers in response.
|
||||
error_message: Error message when rate limit is exceeded.
|
||||
status_code: HTTP status code when rate limit is exceeded.
|
||||
skip_on_error: Skip rate limiting if backend errors occur.
|
||||
cost: Cost of each request (default 1).
|
||||
exempt_when: Function to determine if request should be exempt.
|
||||
on_blocked: Callback when a request is blocked.
|
||||
|
||||
Returns:
|
||||
Decorated function with rate limiting applied.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi_traffic import rate_limit
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/api/resource")
|
||||
@rate_limit(100, 60) # 100 requests per minute
|
||||
async def get_resource():
|
||||
return {"message": "Hello"}
|
||||
```
|
||||
"""
|
||||
config = RateLimitConfig(
|
||||
limit=limit,
|
||||
window_size=window_size,
|
||||
algorithm=algorithm,
|
||||
key_prefix=key_prefix,
|
||||
key_extractor=key_extractor,
|
||||
burst_size=burst_size,
|
||||
include_headers=include_headers,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
skip_on_error=skip_on_error,
|
||||
cost=cost,
|
||||
exempt_when=exempt_when,
|
||||
on_blocked=on_blocked,
|
||||
)
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
request = _extract_request(args, kwargs)
|
||||
if request is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
limiter = get_limiter()
|
||||
result = await limiter.hit(request, config)
|
||||
|
||||
response = await func(*args, **kwargs)
|
||||
|
||||
if config.include_headers and hasattr(response, "headers"):
|
||||
for key, value in result.info.to_headers().items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
import asyncio
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
async_wrapper(*args, **kwargs)
|
||||
)
|
||||
|
||||
if _is_coroutine_function(func):
|
||||
return async_wrapper # type: ignore[return-value]
|
||||
return sync_wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _extract_request(
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Request | None:
|
||||
"""Extract the Request object from function arguments."""
|
||||
from starlette.requests import Request
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
return arg
|
||||
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, Request):
|
||||
return value
|
||||
|
||||
if "request" in kwargs:
|
||||
req = kwargs["request"]
|
||||
if isinstance(req, Request):
|
||||
return req
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_coroutine_function(func: Callable[..., Any]) -> bool:
|
||||
"""Check if a function is a coroutine function."""
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
return asyncio.iscoroutinefunction(func) or inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
class RateLimitDependency:
|
||||
"""FastAPI dependency for rate limiting.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi_traffic import RateLimitDependency
|
||||
|
||||
app = FastAPI()
|
||||
rate_limiter = RateLimitDependency(limit=100, window_size=60)
|
||||
|
||||
@app.get("/api/resource")
|
||||
async def get_resource(rate_limit_info = Depends(rate_limiter)):
|
||||
return {"remaining": rate_limit_info.remaining}
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_config",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: int,
|
||||
window_size: float = 60.0,
|
||||
*,
|
||||
algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
|
||||
key_prefix: str = "ratelimit",
|
||||
key_extractor: KeyExtractor = default_key_extractor,
|
||||
burst_size: int | None = None,
|
||||
error_message: str = "Rate limit exceeded",
|
||||
status_code: int = 429,
|
||||
skip_on_error: bool = False,
|
||||
cost: int = 1,
|
||||
exempt_when: Callable[[Request], bool] | None = None,
|
||||
) -> None:
|
||||
self._config = RateLimitConfig(
|
||||
limit=limit,
|
||||
window_size=window_size,
|
||||
algorithm=algorithm,
|
||||
key_prefix=key_prefix,
|
||||
key_extractor=key_extractor,
|
||||
burst_size=burst_size,
|
||||
include_headers=True,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
skip_on_error=skip_on_error,
|
||||
cost=cost,
|
||||
exempt_when=exempt_when,
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
"""Check rate limit and return info."""
|
||||
limiter = get_limiter()
|
||||
result = await limiter.hit(request, self._config)
|
||||
return result.info
|
||||
|
||||
|
||||
def create_rate_limit_response(
|
||||
exc: RateLimitExceeded,
|
||||
*,
|
||||
include_headers: bool = True,
|
||||
) -> Response:
|
||||
"""Create a rate limit exceeded response.
|
||||
|
||||
Args:
|
||||
exc: The RateLimitExceeded exception.
|
||||
include_headers: Whether to include rate limit headers.
|
||||
|
||||
Returns:
|
||||
A JSONResponse with rate limit information.
|
||||
"""
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if include_headers and exc.limit_info is not None:
|
||||
headers = exc.limit_info.to_headers()
|
||||
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": exc.message,
|
||||
"retry_after": exc.retry_after,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
Reference in New Issue
Block a user