"""Configuration loader for rate limiting settings from .env and .json files.""" from __future__ import annotations import json import os from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar from fastapi_traffic.core.algorithms import Algorithm from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig from fastapi_traffic.exceptions import ConfigurationError if TYPE_CHECKING: from collections.abc import Mapping T = TypeVar("T", RateLimitConfig, GlobalConfig) # Environment variable prefix for config values ENV_PREFIX = "FASTAPI_TRAFFIC_" # Mapping of config field names to their types for validation _RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = { "limit": int, "window_size": float, "algorithm": Algorithm, "key_prefix": str, "burst_size": int, "include_headers": bool, "error_message": str, "status_code": int, "skip_on_error": bool, "cost": int, } _GLOBAL_FIELD_TYPES: dict[str, type[Any]] = { "enabled": bool, "default_limit": int, "default_window_size": float, "default_algorithm": Algorithm, "key_prefix": str, "include_headers": bool, "error_message": str, "status_code": int, "skip_on_error": bool, "exempt_ips": set, "exempt_paths": set, "headers_prefix": str, } # Fields that cannot be loaded from config files (callables, complex objects) _NON_LOADABLE_FIELDS: frozenset[str] = frozenset( { "key_extractor", "exempt_when", "on_blocked", "backend", } ) class ConfigLoader: """Loader for rate limiting configuration from various sources. Supports loading configuration from: - Environment variables (with FASTAPI_TRAFFIC_ prefix) - .env files - JSON files Example usage: >>> loader = ConfigLoader() >>> global_config = loader.load_global_config_from_env() >>> rate_config = loader.load_rate_limit_config_from_json("config.json") """ __slots__ = ("_env_prefix",) def __init__(self, env_prefix: str = ENV_PREFIX) -> None: """Initialize the config loader. Args: env_prefix: Prefix for environment variables. Defaults to "FASTAPI_TRAFFIC_". """ self._env_prefix = env_prefix def _parse_value(self, value: str, target_type: type[Any]) -> Any: """Parse a string value to the target type. Args: value: The string value to parse. target_type: The target type to convert to. Returns: The parsed value. Raises: ConfigurationError: If the value cannot be parsed. """ try: if target_type is bool: return value.lower() in ("true", "1", "yes", "on") if target_type is int: return int(value) if target_type is float: return float(value) if target_type is str: return value if target_type is Algorithm: return Algorithm(value.lower()) if target_type is set: # Parse comma-separated values if not value.strip(): return set() return {item.strip() for item in value.split(",") if item.strip()} except (ValueError, KeyError) as e: msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}" raise ConfigurationError(msg) from e msg = f"Unsupported type: {target_type}" raise ConfigurationError(msg) def _validate_and_convert( self, data: Mapping[str, Any], field_types: dict[str, type[Any]], ) -> dict[str, Any]: """Validate and convert configuration data. Args: data: Raw configuration data. field_types: Mapping of field names to their expected types. Returns: Validated and converted configuration dictionary. Raises: ConfigurationError: If validation fails. """ result: dict[str, Any] = {} for key, value in data.items(): if key in _NON_LOADABLE_FIELDS: msg = f"Field '{key}' cannot be loaded from configuration files" raise ConfigurationError(msg) if key not in field_types: msg = f"Unknown configuration field: '{key}'" raise ConfigurationError(msg) target_type = field_types[key] if isinstance(value, str): result[key] = self._parse_value(value, target_type) elif target_type is set and isinstance(value, list): result[key] = set(value) elif target_type is Algorithm and isinstance(value, str): result[key] = Algorithm(value.lower()) elif isinstance(value, target_type): result[key] = value elif target_type is float and isinstance(value, int): result[key] = float(value) else: msg = f"Invalid type for '{key}': expected {target_type.__name__}, got {type(value).__name__}" raise ConfigurationError(msg) return result def _load_dotenv_file(self, file_path: Path) -> dict[str, str]: """Load environment variables from a .env file. Args: file_path: Path to the .env file. Returns: Dictionary of environment variable names to values. Raises: ConfigurationError: If the file cannot be read or parsed. """ if not file_path.exists(): msg = f"Configuration file not found: {file_path}" raise ConfigurationError(msg) env_vars: dict[str, str] = {} try: with file_path.open(encoding="utf-8") as f: for line_num, line in enumerate(f, start=1): line = line.strip() # Skip empty lines and comments if not line or line.startswith("#"): continue # Parse key=value pairs if "=" not in line: msg = f"Invalid line {line_num} in {file_path}: missing '='" raise ConfigurationError(msg) key, _, value = line.partition("=") key = key.strip() value = value.strip() # Remove surrounding quotes if present if ( len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'") ): value = value[1:-1] env_vars[key] = value except OSError as e: msg = f"Failed to read configuration file {file_path}: {e}" raise ConfigurationError(msg) from e return env_vars def _load_json_file(self, file_path: Path) -> Any: """Load configuration from a JSON file. Args: file_path: Path to the JSON file. Returns: Parsed JSON data (could be any JSON type). Raises: ConfigurationError: If the file cannot be read or parsed. """ if not file_path.exists(): msg = f"Configuration file not found: {file_path}" raise ConfigurationError(msg) try: with file_path.open(encoding="utf-8") as f: data: Any = json.load(f) except json.JSONDecodeError as e: msg = f"Invalid JSON in {file_path}: {e}" raise ConfigurationError(msg) from e except OSError as e: msg = f"Failed to read configuration file {file_path}: {e}" raise ConfigurationError(msg) from e return data def _extract_env_config( self, prefix: str, field_types: dict[str, type[Any]], env_source: Mapping[str, str] | None = None, ) -> dict[str, str]: """Extract configuration from environment variables. Args: prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_"). field_types: Mapping of field names to their expected types. env_source: Optional source of environment variables. Defaults to os.environ. Returns: Dictionary of field names to their string values. """ source = env_source if env_source is not None else os.environ full_prefix = f"{self._env_prefix}{prefix}" result: dict[str, str] = {} for key, value in source.items(): if key.startswith(full_prefix): field_name = key[len(full_prefix) :].lower() if field_name in field_types: result[field_name] = value return result def load_rate_limit_config_from_env( self, env_source: Mapping[str, str] | None = None, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from environment variables. Environment variables should be prefixed with FASTAPI_TRAFFIC_RATE_LIMIT_ (e.g., FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT=100). Args: env_source: Optional source of environment variables. Defaults to os.environ. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config( "RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source ) config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES: config_dict[key] = value # Ensure required field 'limit' is present if "limit" not in config_dict: msg = "Required field 'limit' not found in environment configuration" raise ConfigurationError(msg) return RateLimitConfig(**config_dict) def load_rate_limit_config_from_dotenv( self, file_path: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a .env file. Args: file_path: Path to the .env file. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) env_vars = self._load_dotenv_file(path) return self.load_rate_limit_config_from_env(env_vars, **overrides) def load_rate_limit_config_from_json( self, file_path: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a JSON file. Args: file_path: Path to the JSON file. **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) raw_config = self._load_json_file(path) if not isinstance(raw_config, dict): msg = "JSON root must be an object" raise ConfigurationError(msg) config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES: config_dict[key] = value # Ensure required field 'limit' is present if "limit" not in config_dict: msg = "Required field 'limit' not found in JSON configuration" raise ConfigurationError(msg) return RateLimitConfig(**config_dict) def load_global_config_from_env( self, env_source: Mapping[str, str] | None = None, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from environment variables. Environment variables should be prefixed with FASTAPI_TRAFFIC_GLOBAL_ (e.g., FASTAPI_TRAFFIC_GLOBAL_ENABLED=true). Args: env_source: Optional source of environment variables. Defaults to os.environ. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config( "GLOBAL_", _GLOBAL_FIELD_TYPES, env_source ) config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES: config_dict[key] = value return GlobalConfig(**config_dict) def load_global_config_from_dotenv( self, file_path: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a .env file. Args: file_path: Path to the .env file. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) env_vars = self._load_dotenv_file(path) return self.load_global_config_from_env(env_vars, **overrides) def load_global_config_from_json( self, file_path: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a JSON file. Args: file_path: Path to the JSON file. **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid. """ path = Path(file_path) raw_config = self._load_json_file(path) config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES) # Apply overrides for key, value in overrides.items(): if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES: config_dict[key] = value return GlobalConfig(**config_dict) # Convenience functions for direct usage _default_loader: ConfigLoader | None = None def _get_default_loader() -> ConfigLoader: """Get or create the default config loader.""" global _default_loader if _default_loader is None: _default_loader = ConfigLoader() return _default_loader def load_rate_limit_config( source: str | Path, **overrides: Any, ) -> RateLimitConfig: """Load RateLimitConfig from a file (auto-detects format). Args: source: Path to configuration file (.env or .json). **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. Raises: ConfigurationError: If configuration is invalid or format unknown. """ loader = _get_default_loader() path = Path(source) if path.suffix.lower() == ".json": return loader.load_rate_limit_config_from_json(path, **overrides) if path.suffix.lower() in (".env", "") or path.name.startswith(".env"): return loader.load_rate_limit_config_from_dotenv(path, **overrides) msg = f"Unknown configuration file format: {path.suffix}" raise ConfigurationError(msg) def load_global_config( source: str | Path, **overrides: Any, ) -> GlobalConfig: """Load GlobalConfig from a file (auto-detects format). Args: source: Path to configuration file (.env or .json). **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. Raises: ConfigurationError: If configuration is invalid or format unknown. """ loader = _get_default_loader() path = Path(source) if path.suffix.lower() == ".json": return loader.load_global_config_from_json(path, **overrides) if path.suffix.lower() in (".env", "") or path.name.startswith(".env"): return loader.load_global_config_from_dotenv(path, **overrides) msg = f"Unknown configuration file format: {path.suffix}" raise ConfigurationError(msg) def load_rate_limit_config_from_env(**overrides: Any) -> RateLimitConfig: """Load RateLimitConfig from environment variables. Args: **overrides: Additional values to override loaded config. Returns: Configured RateLimitConfig instance. """ return _get_default_loader().load_rate_limit_config_from_env(**overrides) def load_global_config_from_env(**overrides: Any) -> GlobalConfig: """Load GlobalConfig from environment variables. Args: **overrides: Additional values to override loaded config. Returns: Configured GlobalConfig instance. """ return _get_default_loader().load_global_config_from_env(**overrides)