"""Configuration for rate limiting.""" from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from fastapi_traffic.core.algorithms import Algorithm if TYPE_CHECKING: from starlette.requests import Request from fastapi_traffic.backends.base import Backend KeyExtractor = Callable[["Request"], str] def default_key_extractor(request: Request) -> str: """Extract client IP as the default rate limit key.""" forwarded = request.headers.get("X-Forwarded-For") if forwarded: return forwarded.split(",")[0].strip() real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip if request.client: return request.client.host return "unknown" @dataclass(slots=True) class RateLimitConfig: """Configuration for a rate limit rule.""" limit: int window_size: float = 60.0 algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER key_prefix: str = "ratelimit" key_extractor: KeyExtractor = field(default=default_key_extractor) burst_size: int | None = None include_headers: bool = True error_message: str = "Rate limit exceeded" status_code: int = 429 skip_on_error: bool = False cost: int = 1 exempt_when: Callable[[Request], bool] | None = None on_blocked: Callable[[Request, Any], Any] | None = None def __post_init__(self) -> None: if self.limit <= 0: msg = "limit must be positive" raise ValueError(msg) if self.window_size <= 0: msg = "window_size must be positive" raise ValueError(msg) if self.cost <= 0: msg = "cost must be positive" raise ValueError(msg) @dataclass(slots=True) class GlobalConfig: """Global configuration for the rate limiter.""" backend: Backend | None = None enabled: bool = True default_limit: int = 100 default_window_size: float = 60.0 default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER key_prefix: str = "fastapi_traffic" include_headers: bool = True error_message: str = "Rate limit exceeded. Please try again later." status_code: int = 429 skip_on_error: bool = False exempt_ips: set[str] = field(default_factory=set) exempt_paths: set[str] = field(default_factory=set) headers_prefix: str = "X-RateLimit"