Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
301
fastapi_traffic/core/limiter.py
Normal file
301
fastapi_traffic/core/limiter.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Core rate limiter implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm
|
||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
||||
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
|
||||
from fastapi_traffic.exceptions import BackendError, RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Main rate limiter class that manages rate limiting logic."""
|
||||
|
||||
__slots__ = ("_config", "_backend", "_algorithms", "_initialized")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: Backend | None = None,
|
||||
*,
|
||||
config: GlobalConfig | None = None,
|
||||
) -> None:
|
||||
"""Initialize the rate limiter.
|
||||
|
||||
Args:
|
||||
backend: Storage backend for rate limit data.
|
||||
config: Global configuration options.
|
||||
"""
|
||||
self._config = config or GlobalConfig()
|
||||
self._backend = backend or self._config.backend or MemoryBackend()
|
||||
self._algorithms: dict[str, BaseAlgorithm] = {}
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def backend(self) -> Backend:
|
||||
"""Get the storage backend."""
|
||||
return self._backend
|
||||
|
||||
@property
|
||||
def config(self) -> GlobalConfig:
|
||||
"""Get the global configuration."""
|
||||
return self._config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the rate limiter and backend."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
if hasattr(self._backend, "initialize"):
|
||||
await self._backend.initialize() # type: ignore[union-attr]
|
||||
|
||||
if hasattr(self._backend, "start_cleanup"):
|
||||
await self._backend.start_cleanup() # type: ignore[union-attr]
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the rate limiter and cleanup resources."""
|
||||
await self._backend.close()
|
||||
self._algorithms.clear()
|
||||
self._initialized = False
|
||||
|
||||
def _get_algorithm(
|
||||
self,
|
||||
limit: int,
|
||||
window_size: float,
|
||||
algorithm: Algorithm,
|
||||
burst_size: int | None = None,
|
||||
) -> BaseAlgorithm:
|
||||
"""Get or create an algorithm instance."""
|
||||
cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}"
|
||||
if cache_key not in self._algorithms:
|
||||
self._algorithms[cache_key] = get_algorithm(
|
||||
algorithm,
|
||||
limit,
|
||||
window_size,
|
||||
self._backend,
|
||||
burst_size=burst_size,
|
||||
)
|
||||
return self._algorithms[cache_key]
|
||||
|
||||
def _build_key(
|
||||
self,
|
||||
request: Request,
|
||||
config: RateLimitConfig,
|
||||
identifier: str | None = None,
|
||||
) -> str:
|
||||
"""Build the rate limit key for a request."""
|
||||
if identifier:
|
||||
client_id = identifier
|
||||
else:
|
||||
client_id = config.key_extractor(request)
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
return f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}"
|
||||
|
||||
def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool:
|
||||
"""Check if the request is exempt from rate limiting."""
|
||||
if not self._config.enabled:
|
||||
return True
|
||||
|
||||
if config.exempt_when is not None and config.exempt_when(request):
|
||||
return True
|
||||
|
||||
client_ip = config.key_extractor(request)
|
||||
if client_ip in self._config.exempt_ips:
|
||||
return True
|
||||
|
||||
if request.url.path in self._config.exempt_paths:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def check(
|
||||
self,
|
||||
request: Request,
|
||||
config: RateLimitConfig,
|
||||
*,
|
||||
identifier: str | None = None,
|
||||
cost: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""Check if a request is allowed under the rate limit.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
config: Rate limit configuration for this endpoint.
|
||||
identifier: Optional custom identifier override.
|
||||
cost: Optional cost override for this request.
|
||||
|
||||
Returns:
|
||||
RateLimitResult with allowed status and limit info.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if self._is_exempt(request, config):
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
info=RateLimitInfo(
|
||||
limit=config.limit,
|
||||
remaining=config.limit,
|
||||
reset_at=0,
|
||||
window_size=config.window_size,
|
||||
),
|
||||
key="exempt",
|
||||
)
|
||||
|
||||
key = self._build_key(request, config, identifier)
|
||||
actual_cost = cost or config.cost
|
||||
|
||||
try:
|
||||
algorithm = self._get_algorithm(
|
||||
config.limit,
|
||||
config.window_size,
|
||||
config.algorithm,
|
||||
config.burst_size,
|
||||
)
|
||||
|
||||
info: RateLimitInfo | None = None
|
||||
for _ in range(actual_cost):
|
||||
allowed, info = await algorithm.check(key)
|
||||
if not allowed:
|
||||
return RateLimitResult(allowed=False, info=info, key=key)
|
||||
|
||||
if info is None:
|
||||
info = RateLimitInfo(
|
||||
limit=config.limit,
|
||||
remaining=config.limit,
|
||||
reset_at=0,
|
||||
window_size=config.window_size,
|
||||
)
|
||||
return RateLimitResult(allowed=True, info=info, key=key)
|
||||
|
||||
except BackendError as e:
|
||||
logger.warning("Backend error during rate limit check: %s", e)
|
||||
if config.skip_on_error:
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
info=RateLimitInfo(
|
||||
limit=config.limit,
|
||||
remaining=config.limit,
|
||||
reset_at=0,
|
||||
window_size=config.window_size,
|
||||
),
|
||||
key=key,
|
||||
)
|
||||
raise
|
||||
|
||||
async def hit(
|
||||
self,
|
||||
request: Request,
|
||||
config: RateLimitConfig,
|
||||
*,
|
||||
identifier: str | None = None,
|
||||
cost: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""Check rate limit and raise exception if exceeded.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
config: Rate limit configuration for this endpoint.
|
||||
identifier: Optional custom identifier override.
|
||||
cost: Optional cost override for this request.
|
||||
|
||||
Returns:
|
||||
RateLimitResult if allowed.
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If the rate limit is exceeded.
|
||||
"""
|
||||
result = await self.check(request, config, identifier=identifier, cost=cost)
|
||||
|
||||
if not result.allowed:
|
||||
if config.on_blocked is not None:
|
||||
config.on_blocked(request, result)
|
||||
|
||||
raise RateLimitExceeded(
|
||||
config.error_message,
|
||||
retry_after=result.info.retry_after,
|
||||
limit_info=result.info,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def reset(
|
||||
self,
|
||||
request: Request,
|
||||
config: RateLimitConfig,
|
||||
*,
|
||||
identifier: str | None = None,
|
||||
) -> None:
|
||||
"""Reset the rate limit for a specific key.
|
||||
|
||||
Args:
|
||||
request: The request to reset limits for.
|
||||
config: Rate limit configuration.
|
||||
identifier: Optional custom identifier override.
|
||||
"""
|
||||
key = self._build_key(request, config, identifier)
|
||||
algorithm = self._get_algorithm(
|
||||
config.limit,
|
||||
config.window_size,
|
||||
config.algorithm,
|
||||
config.burst_size,
|
||||
)
|
||||
await algorithm.reset(key)
|
||||
|
||||
async def get_state(
|
||||
self,
|
||||
request: Request,
|
||||
config: RateLimitConfig,
|
||||
*,
|
||||
identifier: str | None = None,
|
||||
) -> RateLimitInfo | None:
|
||||
"""Get the current rate limit state without consuming a token.
|
||||
|
||||
Args:
|
||||
request: The request to check.
|
||||
config: Rate limit configuration.
|
||||
identifier: Optional custom identifier override.
|
||||
|
||||
Returns:
|
||||
Current rate limit info or None if no state exists.
|
||||
"""
|
||||
key = self._build_key(request, config, identifier)
|
||||
algorithm = self._get_algorithm(
|
||||
config.limit,
|
||||
config.window_size,
|
||||
config.algorithm,
|
||||
config.burst_size,
|
||||
)
|
||||
return await algorithm.get_state(key)
|
||||
|
||||
|
||||
_default_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_limiter() -> RateLimiter:
|
||||
"""Get the default rate limiter instance."""
|
||||
global _default_limiter
|
||||
if _default_limiter is None:
|
||||
_default_limiter = RateLimiter()
|
||||
return _default_limiter
|
||||
|
||||
|
||||
def set_limiter(limiter: RateLimiter) -> None:
|
||||
"""Set the default rate limiter instance."""
|
||||
global _default_limiter
|
||||
_default_limiter = limiter
|
||||
Reference in New Issue
Block a user