"""Configuration loader for rate limiting settings from .env and .json files.""" from __future__ import annotations import json import os from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar, cast from pydantic import BaseModel, ConfigDict, ValidationError, field_validator from fastapi_traffic.core.algorithms import Algorithm from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig from fastapi_traffic.exceptions import ConfigurationError if TYPE_CHECKING: from collections.abc import Mapping T = TypeVar("T", RateLimitConfig, GlobalConfig) # Environment variable prefix for config values ENV_PREFIX = "FASTAPI_TRAFFIC_" # Fields that cannot be loaded from config files (callables, complex objects) _NON_LOADABLE_FIELDS: frozenset[str] = frozenset( { "key_extractor", "exempt_when", "on_blocked", "backend", } ) class _RateLimitSchema(BaseModel): """Pydantic schema for validating rate limit configuration input.""" model_config = ConfigDict(extra="forbid") limit: int window_size: float = 60.0 algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER key_prefix: str = "ratelimit" burst_size: int | None = None include_headers: bool = True error_message: str = "Rate limit exceeded" status_code: int = 429 skip_on_error: bool = False cost: int = 1 @field_validator("algorithm", mode="before") @classmethod def _normalize_algorithm(cls, v: Any) -> Any: if isinstance(v, str): return v.lower() return v class _GlobalSchema(BaseModel): """Pydantic schema for validating global configuration input.""" model_config = ConfigDict(extra="forbid") enabled: bool = True default_limit: int = 100 default_window_size: float = 60.0 default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER key_prefix: str = "fastapi_traffic" include_headers: bool = True error_message: str = "Rate limit exceeded. Please try again later." status_code: int = 429 skip_on_error: bool = False exempt_ips: set[str] = set() exempt_paths: set[str] = set() headers_prefix: str = "X-RateLimit" @field_validator("default_algorithm", mode="before") @classmethod def _normalize_algorithm(cls, v: Any) -> Any: if isinstance(v, str): return v.lower() return v # Known field names per schema (used for env-var extraction) _RATE_LIMIT_FIELDS: frozenset[str] = frozenset(_RateLimitSchema.model_fields.keys()) _GLOBAL_FIELDS: frozenset[str] = frozenset(_GlobalSchema.model_fields.keys()) def _check_non_loadable(data: Mapping[str, Any]) -> None: """Raise ConfigurationError if data contains non-loadable fields.""" for key in data: if key in _NON_LOADABLE_FIELDS: msg = f"Field '{key}' cannot be loaded from configuration files" raise ConfigurationError(msg) def _format_validation_error(exc: ValidationError) -> str: """Convert a Pydantic ValidationError to a user-friendly message.""" errors = exc.errors() if not errors: return str(exc) err = errors[0] loc = ".".join(str(p) for p in err["loc"]) if err["loc"] else "unknown" err_type = err["type"] msg = err["msg"] ctx = err.get("ctx", {}) if err_type == "extra_forbidden": return f"Unknown configuration field: '{loc}'" if err_type in ("int_parsing", "float_parsing"): input_val = ctx.get("error", err.get("input", "")) return f"Cannot parse value '{input_val}' as {loc}: {msg}" if err_type == "bool_parsing": return f"Cannot parse value as bool for '{loc}': {msg}" if "enum" in err_type or err_type == "value_error": input_val = err.get("input", "") return f"Cannot parse value '{input_val}' as {loc}: {msg}" return f"Invalid value for '{loc}': {msg}" class ConfigLoader: """Loader for rate limiting configuration from various sources. Supports loading configuration from: - Environment variables (with FASTAPI_TRAFFIC_ prefix) - .env files - JSON files Example usage: >>> loader = ConfigLoader() >>> global_config = loader.load_global_config_from_env() >>> rate_config = loader.load_rate_limit_config_from_json("config.json") """ __slots__ = ("_env_prefix",) def __init__(self, env_prefix: str = ENV_PREFIX) -> None: """Initialize the config loader. Args: env_prefix: Prefix for environment variables. Defaults to "FASTAPI_TRAFFIC_". """ self._env_prefix = env_prefix def _load_dotenv_file(self, file_path: Path) -> dict[str, str]: """Load environment variables from a .env file. Args: file_path: Path to the .env file. Returns: Dictionary of environment variable names to values. Raises: ConfigurationError: If the file cannot be read or parsed. """ if not file_path.exists(): msg = f"Configuration file not found: {file_path}" raise ConfigurationError(msg) env_vars: dict[str, str] = {} try: with file_path.open(encoding="utf-8") as f: for line_num, line in enumerate(f, start=1): line = line.strip() # Skip empty lines and comments if not line or line.startswith("#"): continue # Parse key=value pairs if "=" not in line: msg = f"Invalid line {line_num} in {file_path}: missing '='" raise ConfigurationError(msg) key, _, value = line.partition("=") key = key.strip() value = value.strip() # Remove surrounding quotes if present if ( len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'") ): value = value[1:-1] env_vars[key] = value except OSError as e: msg = f"Failed to read configuration file {file_path}: {e}" raise ConfigurationError(msg) from e return env_vars def _load_json_file(self, file_path: Path) -> Any: """Load configuration from a JSON file. Args: file_path: Path to the JSON file. Returns: Parsed JSON data (could be any JSON type). Raises: ConfigurationError: If the file cannot be read or parsed. """ if not file_path.exists(): msg = f"Configuration file not found: {file_path}" raise ConfigurationError(msg) try: with file_path.open(encoding="utf-8") as f: data: Any = json.load(f) except json.JSONDecodeError as e: msg = f"Invalid JSON in {file_path}: {e}" raise ConfigurationError(msg) from e except OSError as e: msg = f"Failed to read configuration file {file_path}: {e}" raise ConfigurationError(msg) from e return data def _extract_env_config( self, prefix: str, known_fields: frozenset[str], env_source: Mapping[str, str] | None = None, ) -> dict[str, str]: """Extract configuration from environment variables. Args: prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_"). known_fields: Set of known field names. env_source: Optional source of environment variables. Defaults to os.environ. Returns: Dictionary of field names to their string values. """ source = env_source if env_source is not None else os.environ full_prefix = f"{self._env_prefix}{prefix}" result: dict[str, str] = {} for key, value in source.items(): if key.startswith(full_prefix): field_name = key[len(full_prefix) :].lower() if field_name in known_fields: result[field_name] = value return result def _parse_set_from_string(self, data: dict[str, Any]) -> dict[str, Any]: """Pre-process comma-separated string values into lists for set fields. This handles the env-var case where sets are represented as comma-separated strings (e.g., "127.0.0.1, 10.0.0.1"). """ result = dict(data) for key in ("exempt_ips", "exempt_paths"): if key in result and isinstance(result[key], str): value = result[key].strip() if not value: result[key] = [] else: result[key] = [ item.strip() for item in value.split(",") if item.strip() ] return result def load_rate_limit_config_from_env( self, env_source: Mapping[str, str] | None = None, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from environment variables. Environment variables should be prefixed with FASTAPI_TRAFFIC_RATE_LIMIT_ (e.g., FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT=100). Args: env_source: Optional source of environment variables. Defaults to os.environ. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config( "RATE_LIMIT_", _RATE_LIMIT_FIELDS, env_source ) _check_non_loadable(raw_config) # Merge loadable overrides before validation so required fields can be supplied non_loadable_overrides: dict[str, Any] = {} for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS: non_loadable_overrides[key] = value elif key in _RATE_LIMIT_FIELDS: raw_config[key] = value try: schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime except ValidationError as e: raise ConfigurationError(_format_validation_error(e)) from e config_dict = schema.model_dump(exclude_defaults=True) # Apply non-loadable overrides (callables, etc.) config_dict.update(non_loadable_overrides) # Ensure required field 'limit' is present if "limit" not in config_dict: msg = "Required field 'limit' not found in environment configuration" raise ConfigurationError(msg) return RateLimitConfig(**config_dict) def load_rate_limit_config_from_dotenv( self, file_path: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a .env file. Args: file_path: Path to the .env file. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) env_vars = self._load_dotenv_file(path) return self.load_rate_limit_config_from_env(env_vars, **overrides) def load_rate_limit_config_from_json( self, file_path: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a JSON file. Args: file_path: Path to the JSON file. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) raw_config = self._load_json_file(path) if not isinstance(raw_config, dict): msg = "JSON root must be an object" raise ConfigurationError(msg) _check_non_loadable(cast("dict[str, Any]", raw_config)) # Merge loadable overrides before validation so required fields can be supplied non_loadable_overrides: dict[str, Any] = {} for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS: non_loadable_overrides[key] = value elif key in _RATE_LIMIT_FIELDS: raw_config[key] = value try: schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime except ValidationError as e: raise ConfigurationError(_format_validation_error(e)) from e config_dict = schema.model_dump(exclude_defaults=True) # Apply non-loadable overrides (callables, etc.) config_dict.update(non_loadable_overrides) # Ensure required field 'limit' is present if "limit" not in config_dict: msg = "Required field 'limit' not found in JSON configuration" raise ConfigurationError(msg) return RateLimitConfig(**config_dict) def load_global_config_from_env( self, env_source: Mapping[str, str] | None = None, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from environment variables. Environment variables should be prefixed with FASTAPI_TRAFFIC_GLOBAL_ (e.g., FASTAPI_TRAFFIC_GLOBAL_ENABLED=true). Args: env_source: Optional source of environment variables. Defaults to os.environ. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config("GLOBAL_", _GLOBAL_FIELDS, env_source) _check_non_loadable(raw_config) # Pre-process comma-separated strings into lists for set fields processed = self._parse_set_from_string(raw_config) try: schema = _GlobalSchema(**processed) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime except ValidationError as e: raise ConfigurationError(_format_validation_error(e)) from e config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS: config_dict[key] = value return GlobalConfig(**config_dict) def load_global_config_from_dotenv( self, file_path: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a .env file. Args: file_path: Path to the .env file. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) env_vars = self._load_dotenv_file(path) return self.load_global_config_from_env(env_vars, **overrides) def load_global_config_from_json( self, file_path: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a JSON file. Args: file_path: Path to the JSON file. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) raw_config = self._load_json_file(path) if not isinstance(raw_config, dict): msg = "JSON root must be an object" raise ConfigurationError(msg) _check_non_loadable(cast("dict[str, Any]", raw_config)) try: schema = _GlobalSchema(**cast("dict[str, Any]", raw_config)) except ValidationError as e: raise ConfigurationError(_format_validation_error(e)) from e config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS: config_dict[key] = value return GlobalConfig(**config_dict) # Convenience functions for direct usage _default_loader: ConfigLoader | None = None def _get_default_loader() -> ConfigLoader: """Get or create the default config loader.""" global _default_loader if _default_loader is None: _default_loader = ConfigLoader() return _default_loader def load_rate_limit_config( source: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a file (auto-detects format). Args: source: Path to configuration file (.env or .json). **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid or format unknown. """ loader = _get_default_loader() path = Path(source) if path.suffix.lower() == ".json": return loader.load_rate_limit_config_from_json(path, **overrides) if path.suffix.lower() in (".env", "") or path.name.startswith(".env"): return loader.load_rate_limit_config_from_dotenv(path, **overrides) msg = f"Unknown configuration file format: {path.suffix}" raise ConfigurationError(msg) def load_global_config( source: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a file (auto-detects format). Args: source: Path to configuration file (.env or .json). **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid or format unknown. """ loader = _get_default_loader() path = Path(source) if path.suffix.lower() == ".json": return loader.load_global_config_from_json(path, **overrides) if path.suffix.lower() in (".env", "") or path.name.startswith(".env"): return loader.load_global_config_from_dotenv(path, **overrides) msg = f"Unknown configuration file format: {path.suffix}" raise ConfigurationError(msg) def load_rate_limit_config_from_env(**overrides: Any) -> RateLimitConfig: """Load RateLimitConfig from environment variables. Args: **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. """ return _get_default_loader().load_rate_limit_config_from_env(**overrides) def load_global_config_from_env(**overrides: Any) -> GlobalConfig: """Load GlobalConfig from environment variables. Args: **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. """ return _get_default_loader().load_global_config_from_env(**overrides)