"""Core rate limiter implementation.""" from __future__ import annotations import logging from typing import TYPE_CHECKING from fastapi_traffic.backends.memory import MemoryBackend from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult from fastapi_traffic.exceptions import BackendError, RateLimitExceeded if TYPE_CHECKING: from starlette.requests import Request from fastapi_traffic.backends.base import Backend logger = logging.getLogger(__name__) class RateLimiter: """Main rate limiter class that manages rate limiting logic.""" __slots__ = ("_config", "_backend", "_algorithms", "_initialized") def __init__( self, backend: Backend | None = None, *, config: GlobalConfig | None = None, ) -> None: """Initialize the rate limiter. Args: backend: Storage backend for rate limit data. config: Global configuration options. """ self._config = config or GlobalConfig() self._backend = backend or self._config.backend or MemoryBackend() self._algorithms: dict[str, BaseAlgorithm] = {} self._initialized = False @property def backend(self) -> Backend: """Get the storage backend.""" return self._backend @property def config(self) -> GlobalConfig: """Get the global configuration.""" return self._config async def initialize(self) -> None: """Initialize the rate limiter and backend.""" if self._initialized: return if hasattr(self._backend, "initialize"): await self._backend.initialize() # type: ignore[union-attr] if hasattr(self._backend, "start_cleanup"): await self._backend.start_cleanup() # type: ignore[union-attr] self._initialized = True async def close(self) -> None: """Close the rate limiter and cleanup resources.""" await self._backend.close() self._algorithms.clear() self._initialized = False def _get_algorithm( self, limit: int, window_size: float, algorithm: Algorithm, burst_size: int | None = None, ) -> BaseAlgorithm: """Get or create an algorithm instance.""" cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}" if cache_key not in self._algorithms: self._algorithms[cache_key] = get_algorithm( algorithm, limit, window_size, self._backend, burst_size=burst_size, ) return self._algorithms[cache_key] def _build_key( self, request: Request, config: RateLimitConfig, identifier: str | None = None, ) -> str: """Build the rate limit key for a request.""" if identifier: client_id = identifier else: client_id = config.key_extractor(request) path = request.url.path method = request.method return f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}" def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool: """Check if the request is exempt from rate limiting.""" if not self._config.enabled: return True if config.exempt_when is not None and config.exempt_when(request): return True client_ip = config.key_extractor(request) if client_ip in self._config.exempt_ips: return True if request.url.path in self._config.exempt_paths: return True return False async def check( self, request: Request, config: RateLimitConfig, *, identifier: str | None = None, cost: int | None = None, ) -> RateLimitResult: """Check if a request is allowed under the rate limit. Args: request: The incoming request. config: Rate limit configuration for this endpoint. identifier: Optional custom identifier override. cost: Optional cost override for this request. Returns: RateLimitResult with allowed status and limit info. """ if not self._initialized: await self.initialize() if self._is_exempt(request, config): return RateLimitResult( allowed=True, info=RateLimitInfo( limit=config.limit, remaining=config.limit, reset_at=0, window_size=config.window_size, ), key="exempt", ) key = self._build_key(request, config, identifier) actual_cost = cost or config.cost try: algorithm = self._get_algorithm( config.limit, config.window_size, config.algorithm, config.burst_size, ) info: RateLimitInfo | None = None for _ in range(actual_cost): allowed, info = await algorithm.check(key) if not allowed: return RateLimitResult(allowed=False, info=info, key=key) if info is None: info = RateLimitInfo( limit=config.limit, remaining=config.limit, reset_at=0, window_size=config.window_size, ) return RateLimitResult(allowed=True, info=info, key=key) except BackendError as e: logger.warning("Backend error during rate limit check: %s", e) if config.skip_on_error: return RateLimitResult( allowed=True, info=RateLimitInfo( limit=config.limit, remaining=config.limit, reset_at=0, window_size=config.window_size, ), key=key, ) raise async def hit( self, request: Request, config: RateLimitConfig, *, identifier: str | None = None, cost: int | None = None, ) -> RateLimitResult: """Check rate limit and raise exception if exceeded. Args: request: The incoming request. config: Rate limit configuration for this endpoint. identifier: Optional custom identifier override. cost: Optional cost override for this request. Returns: RateLimitResult if allowed. Raises: RateLimitExceeded: If the rate limit is exceeded. """ result = await self.check(request, config, identifier=identifier, cost=cost) if not result.allowed: if config.on_blocked is not None: config.on_blocked(request, result) raise RateLimitExceeded( config.error_message, retry_after=result.info.retry_after, limit_info=result.info, ) return result async def reset( self, request: Request, config: RateLimitConfig, *, identifier: str | None = None, ) -> None: """Reset the rate limit for a specific key. Args: request: The request to reset limits for. config: Rate limit configuration. identifier: Optional custom identifier override. """ key = self._build_key(request, config, identifier) algorithm = self._get_algorithm( config.limit, config.window_size, config.algorithm, config.burst_size, ) await algorithm.reset(key) async def get_state( self, request: Request, config: RateLimitConfig, *, identifier: str | None = None, ) -> RateLimitInfo | None: """Get the current rate limit state without consuming a token. Args: request: The request to check. config: Rate limit configuration. identifier: Optional custom identifier override. Returns: Current rate limit info or None if no state exists. """ key = self._build_key(request, config, identifier) algorithm = self._get_algorithm( config.limit, config.window_size, config.algorithm, config.burst_size, ) return await algorithm.get_state(key) _default_limiter: RateLimiter | None = None def get_limiter() -> RateLimiter: """Get the default rate limiter instance.""" global _default_limiter if _default_limiter is None: _default_limiter = RateLimiter() return _default_limiter def set_limiter(limiter: RateLimiter) -> None: """Set the default rate limiter instance.""" global _default_limiter _default_limiter = limiter