Initial commit: fastapi-traffic rate limiting library

- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.)
- SQLite and memory backends
- Decorator and dependency injection patterns
- Middleware support
- Example usage files
This commit is contained in:
2026-01-09 00:26:19 +00:00
commit da496746bb
38 changed files with 5790 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
"""Core rate limiting components."""
from fastapi_traffic.core.algorithms import Algorithm
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.decorator import rate_limit
from fastapi_traffic.core.limiter import RateLimiter
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
# Public names re-exported at the package's core level; keep this list in
# sync with the imports above.
__all__ = [
    "Algorithm",
    "RateLimitConfig",
    "rate_limit",
    "RateLimiter",
    "RateLimitInfo",
    "RateLimitResult",
]

View File

@@ -0,0 +1,466 @@
"""Rate limiting algorithms implementation."""
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING
from fastapi_traffic.core.models import RateLimitInfo
if TYPE_CHECKING:
from fastapi_traffic.backends.base import Backend
class Algorithm(str, Enum):
    """Enumerates the supported rate limiting strategies.

    Mixes in ``str`` so members compare equal to, and serialize as,
    their plain string values.
    """

    TOKEN_BUCKET = "token_bucket"  # refill-over-time bucket; permits bursts
    SLIDING_WINDOW = "sliding_window"  # exact per-request timestamp log
    FIXED_WINDOW = "fixed_window"  # counter reset at aligned window edges
    LEAKY_BUCKET = "leaky_bucket"  # constant drain rate; smooths bursts
    SLIDING_WINDOW_COUNTER = "sliding_window_counter"  # weighted two-window counter
class BaseAlgorithm(ABC):
    """Abstract interface shared by all rate limiting algorithms.

    Concrete subclasses implement :meth:`check`, :meth:`reset` and
    :meth:`get_state` on top of a key/value ``Backend``.
    """

    __slots__ = ("limit", "window_size", "backend", "burst_size")

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        # Maximum number of requests allowed per window.
        self.limit = limit
        # Window length in seconds.
        self.window_size = window_size
        # Storage backend used to persist per-key state.
        self.backend = backend
        # Burst capacity falls back to the plain limit when unset (or zero).
        self.burst_size = burst_size if burst_size else limit

    @abstractmethod
    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Consume one unit for *key* and report whether it was allowed."""
        ...

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Discard any stored state for *key*."""
        ...

    @abstractmethod
    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Return the current state for *key* without consuming a unit."""
        ...
class TokenBucketAlgorithm(BaseAlgorithm):
    """Token bucket: tokens refill continuously; bursts drain the bucket."""

    __slots__ = ("refill_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        # Tokens added per second, so `limit` tokens accrue per window.
        self.refill_rate = limit / window_size

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        now = time.time()
        stored = await self.backend.get(key)

        if stored is None:
            # First sighting of this key: start from a full bucket and
            # immediately spend one token for the current request.
            balance = float(self.burst_size - 1)
            await self.backend.set(
                key,
                {"tokens": balance, "last_update": now},
                ttl=self.window_size * 2,
            )
            info = RateLimitInfo(
                limit=self.limit,
                remaining=int(balance),
                reset_at=now + self.window_size,
                window_size=self.window_size,
            )
            return True, info

        # Refill based on the time elapsed since the last update, capped at
        # the bucket capacity.
        balance = float(stored.get("tokens", self.burst_size))
        since = now - float(stored.get("last_update", now))
        balance = min(self.burst_size, balance + since * self.refill_rate)

        allowed = balance >= 1
        if allowed:
            balance -= 1
            retry_after = None
        else:
            # Time until a whole token becomes available again.
            retry_after = (1 - balance) / self.refill_rate

        await self.backend.set(
            key,
            {"tokens": balance, "last_update": now},
            ttl=self.window_size * 2,
        )
        info = RateLimitInfo(
            limit=self.limit,
            remaining=int(balance),
            reset_at=now + (self.burst_size - balance) / self.refill_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )
        return allowed, info

    async def reset(self, key: str) -> None:
        """Drop the stored bucket so the key starts fresh."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the refilled balance without spending a token."""
        now = time.time()
        stored = await self.backend.get(key)
        if stored is None:
            return None
        balance = float(stored.get("tokens", self.burst_size))
        since = now - float(stored.get("last_update", now))
        balance = min(self.burst_size, balance + since * self.refill_rate)
        return RateLimitInfo(
            limit=self.limit,
            remaining=int(balance),
            reset_at=now + (self.burst_size - balance) / self.refill_rate,
            window_size=self.window_size,
        )
class SlidingWindowAlgorithm(BaseAlgorithm):
    """Sliding window log algorithm - precise but memory intensive.

    Every request timestamp inside the window is stored, so per-key
    memory grows up to ``limit`` entries.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Record a hit for *key*, pruning timestamps older than the window."""
        now = time.time()
        window_start = now - self.window_size
        state = await self.backend.get(key)
        timestamps: list[float] = []
        if state is not None:
            # Convert each stored value to float once (the original code
            # converted twice per element), then drop expired entries.
            timestamps = [
                ts for ts in map(float, state.get("timestamps", []))
                if ts > window_start
            ]
        if len(timestamps) < self.limit:
            timestamps.append(now)
            allowed = True
            retry_after = None
        else:
            allowed = False
            # The request can succeed once the oldest entry leaves the window.
            oldest = min(timestamps) if timestamps else now
            retry_after = oldest + self.window_size - now
        await self.backend.set(
            key,
            {"timestamps": timestamps},
            ttl=self.window_size * 2,
        )
        remaining = max(0, self.limit - len(timestamps))
        reset_at = (min(timestamps) if timestamps else now) + self.window_size
        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=remaining,
            reset_at=reset_at,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Discard the stored timestamp log for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Return the current window state without recording a hit."""
        now = time.time()
        window_start = now - self.window_size
        state = await self.backend.get(key)
        if state is None:
            return None
        timestamps = [
            ts for ts in map(float, state.get("timestamps", []))
            if ts > window_start
        ]
        remaining = max(0, self.limit - len(timestamps))
        reset_at = (min(timestamps) if timestamps else now) + self.window_size
        return RateLimitInfo(
            limit=self.limit,
            remaining=remaining,
            reset_at=reset_at,
            window_size=self.window_size,
        )
class FixedWindowAlgorithm(BaseAlgorithm):
    """Fixed window counter: cheap, but bursty around window boundaries."""

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        now = time.time()
        # Align the window start to a multiple of the window size.
        window_start = (now // self.window_size) * self.window_size
        window_end = window_start + self.window_size

        stored = await self.backend.get(key)
        count = 0
        if stored is not None and float(stored.get("window_start", 0)) == window_start:
            # Reuse the counter only if it belongs to the current window.
            count = int(stored.get("count", 0))

        allowed = count < self.limit
        if allowed:
            count += 1
            retry_after = None
        else:
            retry_after = window_end - now

        await self.backend.set(
            key,
            {"count": count, "window_start": window_start},
            ttl=self.window_size * 2,
        )
        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - count),
            reset_at=window_end,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Forget the counter for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the current window's count without incrementing it."""
        now = time.time()
        window_start = (now // self.window_size) * self.window_size
        stored = await self.backend.get(key)
        if stored is None:
            return None
        count = 0
        if float(stored.get("window_start", 0)) == window_start:
            count = int(stored.get("count", 0))
        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - count),
            reset_at=window_start + self.window_size,
            window_size=self.window_size,
        )
class LeakyBucketAlgorithm(BaseAlgorithm):
    """Leaky bucket: each request adds water that drains at a fixed rate."""

    __slots__ = ("leak_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        # Drain speed in units per second: `limit` units per window.
        self.leak_rate = limit / window_size

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        now = time.time()
        stored = await self.backend.get(key)

        level = 0.0
        if stored is not None:
            level = float(stored.get("water_level", 0))
            elapsed = now - float(stored.get("last_update", now))
            # Drain whatever leaked since the last update (never below empty).
            level = max(0, level - elapsed * self.leak_rate)

        allowed = level < self.burst_size
        if allowed:
            level += 1
            retry_after = None
        else:
            # Time until enough water drains to fit one more unit.
            retry_after = (level - self.burst_size + 1) / self.leak_rate

        await self.backend.set(
            key,
            {"water_level": level, "last_update": now},
            ttl=self.window_size * 2,
        )
        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            reset_at=now + level / self.leak_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Empty the bucket for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the drained water level without adding to it."""
        now = time.time()
        stored = await self.backend.get(key)
        if stored is None:
            return None
        level = float(stored.get("water_level", 0))
        elapsed = now - float(stored.get("last_update", now))
        level = max(0, level - elapsed * self.leak_rate)
        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            reset_at=now + level / self.leak_rate,
            window_size=self.window_size,
        )
class SlidingWindowCounterAlgorithm(BaseAlgorithm):
    """Sliding window counter - balance between precision and memory.

    Keeps one counter for the current fixed window and one for the
    previous window, and estimates the rolling total by weighting the
    previous counter by how much of its window still overlaps the
    sliding window.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Estimate the rolling count for *key* and record a hit if allowed."""
        now = time.time()
        # Fixed windows aligned to multiples of window_size.
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        # Fraction of the current window already elapsed (0.0 .. 1.0).
        window_progress = (now - current_window) / self.window_size
        state = await self.backend.get(key)
        prev_count = 0
        curr_count = 0
        if state is not None:
            prev_count = int(state.get("prev_count", 0))
            curr_count = int(state.get("curr_count", 0))
            stored_window = float(state.get("current_window", 0))
            if stored_window < previous_window:
                # State is two or more windows old: nothing overlaps anymore.
                prev_count = 0
                curr_count = 0
            elif stored_window == previous_window:
                # Exactly one window boundary passed: shift current -> previous.
                prev_count = curr_count
                curr_count = 0
        # Previous window contributes proportionally to its remaining overlap.
        weighted_count = prev_count * (1 - window_progress) + curr_count
        if weighted_count < self.limit:
            curr_count += 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            # Approximation: suggest waiting until the current window ends.
            retry_after = self.window_size * (1 - window_progress)
        await self.backend.set(
            key,
            {
                "prev_count": prev_count,
                "curr_count": curr_count,
                "current_window": current_window,
            },
            ttl=self.window_size * 3,
        )
        remaining = max(0, int(self.limit - weighted_count))
        reset_at = current_window + self.window_size
        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=remaining,
            reset_at=reset_at,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Discard both counters for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Compute the weighted state for *key* without consuming a unit."""
        now = time.time()
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        window_progress = (now - current_window) / self.window_size
        state = await self.backend.get(key)
        if state is None:
            return None
        prev_count = int(state.get("prev_count", 0))
        curr_count = int(state.get("curr_count", 0))
        stored_window = float(state.get("current_window", 0))
        if stored_window < previous_window:
            # Stale by more than one window: counters no longer apply.
            prev_count = 0
            curr_count = 0
        elif stored_window == previous_window:
            # One boundary passed since last write: shift current -> previous.
            prev_count = curr_count
            curr_count = 0
        weighted_count = prev_count * (1 - window_progress) + curr_count
        remaining = max(0, int(self.limit - weighted_count))
        reset_at = current_window + self.window_size
        return RateLimitInfo(
            limit=self.limit,
            remaining=remaining,
            reset_at=reset_at,
            window_size=self.window_size,
        )
def get_algorithm(
    algorithm: Algorithm,
    limit: int,
    window_size: float,
    backend: Backend,
    *,
    burst_size: int | None = None,
) -> BaseAlgorithm:
    """Instantiate the implementation class matching *algorithm*.

    Raises:
        ValueError: If *algorithm* has no registered implementation.
    """
    implementations: dict[Algorithm, type[BaseAlgorithm]] = {
        Algorithm.TOKEN_BUCKET: TokenBucketAlgorithm,
        Algorithm.SLIDING_WINDOW: SlidingWindowAlgorithm,
        Algorithm.FIXED_WINDOW: FixedWindowAlgorithm,
        Algorithm.LEAKY_BUCKET: LeakyBucketAlgorithm,
        Algorithm.SLIDING_WINDOW_COUNTER: SlidingWindowCounterAlgorithm,
    }
    try:
        implementation = implementations[algorithm]
    except KeyError:
        msg = f"Unknown algorithm: {algorithm}"
        raise ValueError(msg) from None
    return implementation(limit, window_size, backend, burst_size=burst_size)

View File

@@ -0,0 +1,81 @@
"""Configuration for rate limiting."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from fastapi_traffic.core.algorithms import Algorithm
if TYPE_CHECKING:
from starlette.requests import Request
from fastapi_traffic.backends.base import Backend
KeyExtractor = Callable[["Request"], str]
def default_key_extractor(request: Request) -> str:
    """Identify the requesting client for rate limiting purposes.

    Proxy headers are consulted first (X-Forwarded-For, then X-Real-IP),
    falling back to the socket peer address. NOTE(review): these headers
    are client-controlled unless a trusted proxy sets them — spoofable
    without such a proxy in front.
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Left-most entry is the originating client in a proxy chain.
        return forwarded_for.split(",")[0].strip()
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip
    client = request.client
    return client.host if client else "unknown"
@dataclass(slots=True)
class RateLimitConfig:
    """Configuration for a rate limit rule.

    Validated in __post_init__: limit, window_size and cost must all be
    strictly positive.
    """

    limit: int  # max requests per window
    window_size: float = 60.0  # window length in seconds
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "ratelimit"  # namespace segment in the backend key
    # Default client identity: peer IP (see default_key_extractor above).
    key_extractor: KeyExtractor = field(default=default_key_extractor)
    burst_size: int | None = None  # bucket capacity; None means same as limit
    include_headers: bool = True  # emit X-RateLimit-* response headers
    error_message: str = "Rate limit exceeded"
    status_code: int = 429  # HTTP status used when a request is blocked
    skip_on_error: bool = False  # fail open on backend errors
    cost: int = 1  # units consumed per request
    exempt_when: Callable[[Request], bool] | None = None  # per-request bypass
    on_blocked: Callable[[Request, Any], Any] | None = None  # blocked callback

    def __post_init__(self) -> None:
        # Reject configurations that would make every check meaningless.
        if self.limit <= 0:
            msg = "limit must be positive"
            raise ValueError(msg)
        if self.window_size <= 0:
            msg = "window_size must be positive"
            raise ValueError(msg)
        if self.cost <= 0:
            msg = "cost must be positive"
            raise ValueError(msg)
@dataclass(slots=True)
class GlobalConfig:
    """Global configuration for the rate limiter."""

    backend: Backend | None = None  # storage backend; None -> limiter default
    enabled: bool = True  # master switch; False exempts every request
    default_limit: int = 100
    default_window_size: float = 60.0
    default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "fastapi_traffic"  # leading segment of every backend key
    include_headers: bool = True
    error_message: str = "Rate limit exceeded. Please try again later."
    status_code: int = 429
    skip_on_error: bool = False  # fail open when the backend errors
    exempt_ips: set[str] = field(default_factory=set)  # client IDs never limited
    exempt_paths: set[str] = field(default_factory=set)  # URL paths never limited
    headers_prefix: str = "X-RateLimit"

View File

@@ -0,0 +1,259 @@
"""Rate limit decorator for FastAPI endpoints."""
from __future__ import annotations
import functools
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from fastapi_traffic.core.algorithms import Algorithm
from fastapi_traffic.core.config import KeyExtractor, RateLimitConfig, default_key_extractor
from fastapi_traffic.core.limiter import get_limiter
from fastapi_traffic.exceptions import RateLimitExceeded
if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
F = TypeVar("F", bound=Callable[..., Any])
@overload
def rate_limit(
    limit: int,
    *,
    window_size: float = ...,
    algorithm: Algorithm = ...,
    key_prefix: str = ...,
    key_extractor: KeyExtractor = ...,
    burst_size: int | None = ...,
    include_headers: bool = ...,
    error_message: str = ...,
    status_code: int = ...,
    skip_on_error: bool = ...,
    cost: int = ...,
    exempt_when: Callable[[Request], bool] | None = ...,
    on_blocked: Callable[[Request, Any], Any] | None = ...,
) -> Callable[[F], F]: ...
@overload
def rate_limit(
    limit: int,
    window_size: float,
    /,
) -> Callable[[F], F]: ...
def rate_limit(
    limit: int,
    window_size: float = 60.0,
    *,
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
    key_prefix: str = "ratelimit",
    key_extractor: KeyExtractor = default_key_extractor,
    burst_size: int | None = None,
    include_headers: bool = True,
    error_message: str = "Rate limit exceeded",
    status_code: int = 429,
    skip_on_error: bool = False,
    cost: int = 1,
    exempt_when: Callable[[Request], bool] | None = None,
    on_blocked: Callable[[Request, Any], Any] | None = None,
) -> Callable[[F], F]:
    """Decorator to apply rate limiting to a FastAPI endpoint.

    Args:
        limit: Maximum number of requests allowed in the window.
        window_size: Time window in seconds.
        algorithm: Rate limiting algorithm to use.
        key_prefix: Prefix for the rate limit key.
        key_extractor: Function to extract the client identifier from request.
        burst_size: Maximum burst size (for token bucket/leaky bucket).
        include_headers: Whether to include rate limit headers in response.
        error_message: Error message when rate limit is exceeded.
        status_code: HTTP status code when rate limit is exceeded.
        skip_on_error: Skip rate limiting if backend errors occur.
        cost: Cost of each request (default 1).
        exempt_when: Function to determine if request should be exempt.
        on_blocked: Callback when a request is blocked.

    Returns:
        Decorated function with rate limiting applied.

    Example:
        ```python
        from fastapi import FastAPI
        from fastapi_traffic import rate_limit
        app = FastAPI()
        @app.get("/api/resource")
        @rate_limit(100, 60)  # 100 requests per minute
        async def get_resource():
            return {"message": "Hello"}
        ```
    """
    config = RateLimitConfig(
        limit=limit,
        window_size=window_size,
        algorithm=algorithm,
        key_prefix=key_prefix,
        key_extractor=key_extractor,
        burst_size=burst_size,
        include_headers=include_headers,
        error_message=error_message,
        status_code=status_code,
        skip_on_error=skip_on_error,
        cost=cost,
        exempt_when=exempt_when,
        on_blocked=on_blocked,
    )

    def decorator(func: F) -> F:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            request = _extract_request(args, kwargs)
            if request is None:
                # No Request in the call: nothing to key the limit on.
                return await func(*args, **kwargs)
            limiter = get_limiter()
            result = await limiter.hit(request, config)
            response = await func(*args, **kwargs)
            # Attach X-RateLimit-* headers only to header-bearing responses.
            if config.include_headers and hasattr(response, "headers"):
                for key, value in result.info.to_headers().items():
                    response.headers[key] = value
            return response

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            # Sync endpoints run in a worker thread without an active event
            # loop (FastAPI executes them in a threadpool), so asyncio.run()
            # is safe here and replaces the deprecated
            # get_event_loop().run_until_complete() pattern (deprecated in
            # Python 3.10+, broken on threads without a loop).
            import asyncio

            return asyncio.run(async_wrapper(*args, **kwargs))

        if _is_coroutine_function(func):
            return async_wrapper  # type: ignore[return-value]
        return sync_wrapper  # type: ignore[return-value]

    return decorator
def _extract_request(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Request | None:
    """Extract the Request object from function arguments.

    Positional arguments are searched first, then keyword values.

    Returns:
        The first Request instance found, or None when absent.
    """
    from starlette.requests import Request

    for arg in args:
        if isinstance(arg, Request):
            return arg
    # Scanning every keyword value already covers an explicit "request"
    # kwarg, so the original separate kwargs["request"] lookup was dead code.
    for value in kwargs.values():
        if isinstance(value, Request):
            return value
    return None
def _is_coroutine_function(func: Callable[..., Any]) -> bool:
"""Check if a function is a coroutine function."""
import asyncio
import inspect
return asyncio.iscoroutinefunction(func) or inspect.iscoroutinefunction(func)
class RateLimitDependency:
    """FastAPI dependency that enforces a rate limit per request.

    Example:
        ```python
        from fastapi import FastAPI, Depends
        from fastapi_traffic import RateLimitDependency
        app = FastAPI()
        rate_limiter = RateLimitDependency(limit=100, window_size=60)
        @app.get("/api/resource")
        async def get_resource(rate_limit_info = Depends(rate_limiter)):
            return {"remaining": rate_limit_info.remaining}
        ```
    """

    __slots__ = ("_config",)

    def __init__(
        self,
        limit: int,
        window_size: float = 60.0,
        *,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        key_prefix: str = "ratelimit",
        key_extractor: KeyExtractor = default_key_extractor,
        burst_size: int | None = None,
        error_message: str = "Rate limit exceeded",
        status_code: int = 429,
        skip_on_error: bool = False,
        cost: int = 1,
        exempt_when: Callable[[Request], bool] | None = None,
    ) -> None:
        # Header emission is always enabled for the dependency form; the
        # endpoint receives the RateLimitInfo as the dependency value.
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            burst_size=burst_size,
            include_headers=True,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
            cost=cost,
            exempt_when=exempt_when,
        )

    async def __call__(self, request: Request) -> Any:
        """Enforce the limit for *request* and return the current info."""
        outcome = await get_limiter().hit(request, self._config)
        return outcome.info
def create_rate_limit_response(
    exc: RateLimitExceeded,
    *,
    include_headers: bool = True,
    status_code: int = 429,
) -> Response:
    """Create a rate limit exceeded response.

    Args:
        exc: The RateLimitExceeded exception.
        include_headers: Whether to include rate limit headers.
        status_code: HTTP status for the response (default 429); previously
            hard-coded, now overridable to honor a configured status_code.

    Returns:
        A JSONResponse with rate limit information.
    """
    from starlette.responses import JSONResponse

    headers: dict[str, str] = {}
    if include_headers and exc.limit_info is not None:
        headers = exc.limit_info.to_headers()
    return JSONResponse(
        status_code=status_code,
        content={
            "detail": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=headers,
    )

View File

@@ -0,0 +1,301 @@
"""Core rate limiter implementation."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from fastapi_traffic.backends.memory import MemoryBackend
from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
from fastapi_traffic.exceptions import BackendError, RateLimitExceeded
if TYPE_CHECKING:
from starlette.requests import Request
from fastapi_traffic.backends.base import Backend
logger = logging.getLogger(__name__)
class RateLimiter:
    """Main rate limiter class that manages rate limiting logic."""

    # _algorithms caches algorithm instances; _initialized guards lazy setup.
    __slots__ = ("_config", "_backend", "_algorithms", "_initialized")

    def __init__(
        self,
        backend: Backend | None = None,
        *,
        config: GlobalConfig | None = None,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            backend: Storage backend for rate limit data.
            config: Global configuration options.
        """
        self._config = config or GlobalConfig()
        # Backend precedence: explicit argument, then config.backend, then
        # an in-process MemoryBackend fallback.
        self._backend = backend or self._config.backend or MemoryBackend()
        self._algorithms: dict[str, BaseAlgorithm] = {}
        self._initialized = False

    @property
    def backend(self) -> Backend:
        """Get the storage backend."""
        return self._backend

    @property
    def config(self) -> GlobalConfig:
        """Get the global configuration."""
        return self._config

    async def initialize(self) -> None:
        """Initialize the rate limiter and backend (idempotent)."""
        if self._initialized:
            return
        # Both hooks are optional on backends, hence the hasattr checks.
        if hasattr(self._backend, "initialize"):
            await self._backend.initialize()  # type: ignore[union-attr]
        if hasattr(self._backend, "start_cleanup"):
            await self._backend.start_cleanup()  # type: ignore[union-attr]
        self._initialized = True

    async def close(self) -> None:
        """Close the rate limiter and cleanup resources."""
        await self._backend.close()
        self._algorithms.clear()
        self._initialized = False

    def _get_algorithm(
        self,
        limit: int,
        window_size: float,
        algorithm: Algorithm,
        burst_size: int | None = None,
    ) -> BaseAlgorithm:
        """Get or create an algorithm instance.

        Instances hold only configuration (per-key state lives in the
        backend), so one instance per parameter combination is reused.
        """
        cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}"
        if cache_key not in self._algorithms:
            self._algorithms[cache_key] = get_algorithm(
                algorithm,
                limit,
                window_size,
                self._backend,
                burst_size=burst_size,
            )
        return self._algorithms[cache_key]

    def _build_key(
        self,
        request: Request,
        config: RateLimitConfig,
        identifier: str | None = None,
    ) -> str:
        """Build the rate limit key for a request.

        Layout: global_prefix:config_prefix:METHOD:path:client_id — limits
        are therefore scoped per endpoint and per client.
        """
        if identifier:
            client_id = identifier
        else:
            client_id = config.key_extractor(request)
        path = request.url.path
        method = request.method
        return f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}"

    def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool:
        """Check if the request is exempt from rate limiting."""
        # A globally disabled limiter exempts everything.
        if not self._config.enabled:
            return True
        if config.exempt_when is not None and config.exempt_when(request):
            return True
        client_ip = config.key_extractor(request)
        if client_ip in self._config.exempt_ips:
            return True
        if request.url.path in self._config.exempt_paths:
            return True
        return False

    async def check(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check if a request is allowed under the rate limit.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult with allowed status and limit info.
        """
        if not self._initialized:
            await self.initialize()
        if self._is_exempt(request, config):
            # Exempt requests report a full, untouched quota.
            return RateLimitResult(
                allowed=True,
                info=RateLimitInfo(
                    limit=config.limit,
                    remaining=config.limit,
                    reset_at=0,
                    window_size=config.window_size,
                ),
                key="exempt",
            )
        key = self._build_key(request, config, identifier)
        actual_cost = cost or config.cost
        try:
            algorithm = self._get_algorithm(
                config.limit,
                config.window_size,
                config.algorithm,
                config.burst_size,
            )
            info: RateLimitInfo | None = None
            # Cost is applied as repeated single-unit checks.
            # NOTE(review): if a later iteration is denied, units consumed
            # by earlier iterations are not refunded — confirm intended.
            for _ in range(actual_cost):
                allowed, info = await algorithm.check(key)
                if not allowed:
                    return RateLimitResult(allowed=False, info=info, key=key)
            if info is None:
                # Only reachable when actual_cost is 0 (loop body never ran).
                info = RateLimitInfo(
                    limit=config.limit,
                    remaining=config.limit,
                    reset_at=0,
                    window_size=config.window_size,
                )
            return RateLimitResult(allowed=True, info=info, key=key)
        except BackendError as e:
            logger.warning("Backend error during rate limit check: %s", e)
            if config.skip_on_error:
                # Fail open: report the request as allowed with full quota.
                return RateLimitResult(
                    allowed=True,
                    info=RateLimitInfo(
                        limit=config.limit,
                        remaining=config.limit,
                        reset_at=0,
                        window_size=config.window_size,
                    ),
                    key=key,
                )
            raise

    async def hit(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check rate limit and raise exception if exceeded.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult if allowed.

        Raises:
            RateLimitExceeded: If the rate limit is exceeded.
        """
        result = await self.check(request, config, identifier=identifier, cost=cost)
        if not result.allowed:
            if config.on_blocked is not None:
                # NOTE(review): called synchronously; an async callback's
                # coroutine would not be awaited here — confirm intended.
                config.on_blocked(request, result)
            raise RateLimitExceeded(
                config.error_message,
                retry_after=result.info.retry_after,
                limit_info=result.info,
            )
        return result

    async def reset(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> None:
        """Reset the rate limit for a specific key.

        Args:
            request: The request to reset limits for.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.
        """
        key = self._build_key(request, config, identifier)
        algorithm = self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )
        await algorithm.reset(key)

    async def get_state(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> RateLimitInfo | None:
        """Get the current rate limit state without consuming a token.

        Args:
            request: The request to check.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.

        Returns:
            Current rate limit info or None if no state exists.
        """
        key = self._build_key(request, config, identifier)
        algorithm = self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )
        return await algorithm.get_state(key)
# Process-wide default limiter, created lazily on first use.
_default_limiter: RateLimiter | None = None


def get_limiter() -> RateLimiter:
    """Return the process-wide default RateLimiter, creating it on demand."""
    global _default_limiter
    if _default_limiter is not None:
        return _default_limiter
    _default_limiter = RateLimiter()
    return _default_limiter
def set_limiter(limiter: RateLimiter) -> None:
    """Set the default rate limiter instance.

    Replaces the module-level singleton that get_limiter() returns.
    """
    global _default_limiter
    _default_limiter = limiter

View File

@@ -0,0 +1,89 @@
"""Data models for rate limiting."""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class KeyType(str, Enum):
    """How the client identity behind a rate limit key is derived.

    A ``str`` subclass, so members serialize directly as their values.
    """

    IP = "ip"  # peer / forwarded IP address
    USER = "user"  # authenticated user identity
    API_KEY = "api_key"  # API key credential
    ENDPOINT = "endpoint"  # per-route limiting
    CUSTOM = "custom"  # caller-supplied extractor
@dataclass(frozen=True, slots=True)
class RateLimitInfo:
"""Information about the current rate limit state."""
limit: int
remaining: int
reset_at: float
retry_after: float | None = None
window_size: float = 60.0
def to_headers(self) -> dict[str, str]:
"""Convert rate limit info to HTTP headers."""
headers: dict[str, str] = {
"X-RateLimit-Limit": str(self.limit),
"X-RateLimit-Remaining": str(max(0, self.remaining)),
"X-RateLimit-Reset": str(int(self.reset_at)),
}
if self.retry_after is not None:
headers["Retry-After"] = str(int(self.retry_after))
return headers
@dataclass(frozen=True, slots=True)
class RateLimitResult:
    """Result of a rate limit check."""

    allowed: bool  # whether the request may proceed
    info: RateLimitInfo  # limit state observed at check time
    key: str  # backend key the check ran against ("exempt" when bypassed)
@dataclass(slots=True)
class TokenBucketState:
    """State for token bucket algorithm."""

    tokens: float  # current token balance
    last_update: float  # UNIX timestamp of the last refill/consume
@dataclass(slots=True)
class SlidingWindowState:
    """State for sliding window algorithm."""

    timestamps: list[float] = field(default_factory=list)  # in-window hit times
    count: int = 0  # cached hit count
@dataclass(slots=True)
class FixedWindowState:
    """State for fixed window algorithm."""

    count: int  # hits recorded in the current window
    window_start: float  # aligned UNIX timestamp of the window start
@dataclass(slots=True)
class LeakyBucketState:
    """State for leaky bucket algorithm."""

    water_level: float  # current fill level of the bucket
    last_update: float  # UNIX timestamp of the last drain/add
@dataclass(frozen=True, slots=True)
class BackendRecord:
    """Generic record stored in backends."""

    key: str  # fully-qualified rate limit key
    data: dict[str, Any]  # algorithm-specific state payload
    expires_at: float  # UNIX timestamp after which the record is stale