- Refactor Redis backend connection handling and pool management - Update algorithm implementations with improved type annotations - Enhance config loader validation with stricter Pydantic schemas - Improve decorator and middleware error handling - Expand example scripts with better docstrings and usage patterns - Add new 00_basic_usage.py example for quick start - Reorganize examples directory structure - Fix type annotation inconsistencies across core modules - Update dependencies in pyproject.toml
592 lines
19 KiB
Python
592 lines
19 KiB
Python
"""Configuration loader for rate limiting settings from .env and .json files."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
|
|
|
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
|
|
|
|
from fastapi_traffic.core.algorithms import Algorithm
|
|
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
|
from fastapi_traffic.exceptions import ConfigurationError
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Mapping
|
|
|
|
# Type variable constrained to the two configuration classes this module builds.
T = TypeVar("T", RateLimitConfig, GlobalConfig)

# Prefix shared by every environment variable this module reads
# (a ConfigLoader may be constructed with a different one).
ENV_PREFIX = "FASTAPI_TRAFFIC_"

# Config fields that only Python code can supply (callables, complex objects).
# _check_non_loadable() rejects them when they show up in a .env/JSON source;
# they can still be passed as keyword overrides to the loader methods.
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
    ("key_extractor", "exempt_when", "on_blocked", "backend")
)
|
|
|
|
|
|
class _RateLimitSchema(BaseModel):
    """Validation schema for one rate-limit rule read from a config source.

    Unknown keys are rejected (``extra="forbid"``) so misspelled settings fail
    loudly. String input, as read from environment variables, is coerced by
    Pydantic into the declared field types.
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE: field order determines the order of Pydantic's error list, and only
    # the first error is surfaced to users, so keep this order stable.
    limit: int
    window_size: float = 60.0
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "ratelimit"
    burst_size: int | None = None
    include_headers: bool = True
    error_message: str = "Rate limit exceeded"
    status_code: int = 429
    skip_on_error: bool = False
    cost: int = 1

    @field_validator("algorithm", mode="before")
    @classmethod
    def _normalize_algorithm(cls, value: Any) -> Any:
        """Lower-case a string algorithm name before enum coercion.

        Presumably the ``Algorithm`` values are lower-case, so this lets input
        like ``"TOKEN_BUCKET"`` match; non-string input passes through unchanged.
        """
        if not isinstance(value, str):
            return value
        return value.lower()
|
|
|
|
|
|
class _GlobalSchema(BaseModel):
    """Validation schema for application-wide rate limiting settings.

    Extra keys are forbidden so a misspelled setting raises instead of being
    ignored. Pydantic coerces string input (for example from environment
    variables) into the declared field types.
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE: field order determines the order of Pydantic's error list, and only
    # the first error is surfaced to users, so keep this order stable.
    enabled: bool = True
    default_limit: int = 100
    default_window_size: float = 60.0
    default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "fastapi_traffic"
    include_headers: bool = True
    error_message: str = "Rate limit exceeded. Please try again later."
    status_code: int = 429
    skip_on_error: bool = False
    # Pydantic copies mutable defaults per instance, so a literal set() is safe here.
    exempt_ips: set[str] = set()
    exempt_paths: set[str] = set()
    headers_prefix: str = "X-RateLimit"

    @field_validator("default_algorithm", mode="before")
    @classmethod
    def _normalize_algorithm(cls, value: Any) -> Any:
        """Lower-case a string algorithm name before enum coercion.

        Non-string input (e.g. an ``Algorithm`` member) passes through unchanged.
        """
        return value.lower() if isinstance(value, str) else value
|
|
|
|
|
|
# Field names accepted by each schema. They are used to pick matching environment
# variables and to decide which keyword overrides get validated by the schema.
# Iterating ``model_fields`` (a dict) yields the field names.
_RATE_LIMIT_FIELDS: frozenset[str] = frozenset(_RateLimitSchema.model_fields)
_GLOBAL_FIELDS: frozenset[str] = frozenset(_GlobalSchema.model_fields)
|
|
|
|
|
|
def _check_non_loadable(data: Mapping[str, Any]) -> None:
|
|
"""Raise ConfigurationError if data contains non-loadable fields."""
|
|
for key in data:
|
|
if key in _NON_LOADABLE_FIELDS:
|
|
msg = f"Field '{key}' cannot be loaded from configuration files"
|
|
raise ConfigurationError(msg)
|
|
|
|
|
|
def _format_validation_error(exc: ValidationError) -> str:
|
|
"""Convert a Pydantic ValidationError to a user-friendly message."""
|
|
errors = exc.errors()
|
|
if not errors:
|
|
return str(exc)
|
|
|
|
err = errors[0]
|
|
loc = ".".join(str(p) for p in err["loc"]) if err["loc"] else "unknown"
|
|
err_type = err["type"]
|
|
msg = err["msg"]
|
|
ctx = err.get("ctx", {})
|
|
|
|
if err_type == "extra_forbidden":
|
|
return f"Unknown configuration field: '{loc}'"
|
|
|
|
if err_type in ("int_parsing", "float_parsing"):
|
|
input_val = ctx.get("error", err.get("input", ""))
|
|
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
|
|
|
if err_type == "bool_parsing":
|
|
return f"Cannot parse value as bool for '{loc}': {msg}"
|
|
|
|
if "enum" in err_type or err_type == "value_error":
|
|
input_val = err.get("input", "")
|
|
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
|
|
|
return f"Invalid value for '{loc}': {msg}"
|
|
|
|
|
|
class ConfigLoader:
    """Loader for rate limiting configuration from various sources.

    Supports loading configuration from:
    - Environment variables (with FASTAPI_TRAFFIC_ prefix)
    - .env files
    - JSON files

    Every public loader follows the same flow:
    1. Read the raw values.
    2. Merge in keyword overrides that name schema fields.
    3. Validate the result with a Pydantic schema.
    4. Add overrides for non-loadable fields (callables etc.).
    5. Build the config object.

    Example usage:
        >>> loader = ConfigLoader()
        >>> global_config = loader.load_global_config_from_env()
        >>> rate_config = loader.load_rate_limit_config_from_json("config.json")
    """

    __slots__ = ("_env_prefix",)

    def __init__(self, env_prefix: str = ENV_PREFIX) -> None:
        """Initialize the config loader.

        Args:
            env_prefix: Prefix for environment variables. Defaults to "FASTAPI_TRAFFIC_".
        """
        self._env_prefix = env_prefix

    # ------------------------------------------------------------------
    # File readers
    # ------------------------------------------------------------------

    def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
        """Load environment variables from a .env file.

        Supported syntax, one entry per line:
        - ``KEY=value``; whitespace around the key and the value is stripped.
        - An optional leading ``export`` keyword (``export KEY=value``).
        - A value wrapped in matching single or double quotes has them removed.
        - Blank lines and lines starting with ``#`` are skipped.

        Args:
            file_path: Path to the .env file.

        Returns:
            Dictionary of environment variable names to values. When a key is
            repeated, the last occurrence wins.

        Raises:
            ConfigurationError: If the file does not exist, cannot be read, is not
                valid UTF-8, or has a non-comment line without ``=``.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)

        env_vars: dict[str, str] = {}

        try:
            # "utf-8-sig" drops a leading BOM if present. With plain "utf-8" the
            # BOM would be glued onto the first key, which then never matches the
            # env prefix, so that setting would be silently ignored.
            with file_path.open(encoding="utf-8-sig") as f:
                for line_num, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    # Parse key=value pairs
                    if "=" not in line:
                        msg = f"Invalid line {line_num} in {file_path}: missing '='"
                        raise ConfigurationError(msg)

                    key, _, value = line.partition("=")
                    key = key.strip()
                    value = value.strip()

                    # Accept shell-style "export KEY=value" so the file stays sourceable.
                    key_parts = key.split(None, 1)
                    if len(key_parts) == 2 and key_parts[0] == "export":
                        key = key_parts[1]

                    # Remove surrounding quotes if present
                    if (
                        len(value) >= 2
                        and value[0] == value[-1]
                        and value[0] in ('"', "'")
                    ):
                        value = value[1:-1]

                    env_vars[key] = value

        except UnicodeDecodeError as e:
            # Not an OSError subclass, so it needs its own handler to honour the
            # ConfigurationError contract.
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e

        return env_vars

    def _load_json_file(self, file_path: Path) -> Any:
        """Load configuration from a JSON file.

        Args:
            file_path: Path to the JSON file.

        Returns:
            Parsed JSON data (could be any JSON type).

        Raises:
            ConfigurationError: If the file does not exist, cannot be read, is not
                valid UTF-8, or is not valid JSON.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)

        try:
            # "utf-8-sig" accepts files saved with a BOM; json rejects a leading
            # BOM character when the file is decoded as plain "utf-8".
            with file_path.open(encoding="utf-8-sig") as f:
                data: Any = json.load(f)
        except json.JSONDecodeError as e:
            msg = f"Invalid JSON in {file_path}: {e}"
            raise ConfigurationError(msg) from e
        except UnicodeDecodeError as e:
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e

        return data

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _extract_env_config(
        self,
        prefix: str,
        known_fields: frozenset[str],
        env_source: Mapping[str, str] | None = None,
    ) -> dict[str, str]:
        """Extract configuration from environment variables.

        Args:
            prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
            known_fields: Set of known field names.
            env_source: Optional source of environment variables. Defaults to os.environ.

        Returns:
            Dictionary of field names to their string values. Variables whose
            lower-cased suffix is not in ``known_fields`` are ignored.
        """
        source = env_source if env_source is not None else os.environ
        full_prefix = f"{self._env_prefix}{prefix}"
        result: dict[str, str] = {}

        for key, value in source.items():
            if key.startswith(full_prefix):
                field_name = key[len(full_prefix) :].lower()
                if field_name in known_fields:
                    result[field_name] = value

        return result

    def _parse_set_from_string(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pre-process comma-separated string values into lists for set fields.

        This handles the env-var case where sets are represented as
        comma-separated strings (e.g., "127.0.0.1, 10.0.0.1"). Returns a new
        dict; *data* itself is not modified.
        """
        result = dict(data)
        for key in ("exempt_ips", "exempt_paths"):
            if key in result and isinstance(result[key], str):
                value = result[key].strip()
                if not value:
                    result[key] = []
                else:
                    result[key] = [
                        item.strip() for item in value.split(",") if item.strip()
                    ]
        return result

    @staticmethod
    def _split_overrides(
        overrides: Mapping[str, Any],
        schema_fields: frozenset[str],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Partition keyword overrides by how they must be applied.

        Args:
            overrides: Keyword overrides passed by the caller.
            schema_fields: Field names the target schema validates.

        Returns:
            ``(validated, passthrough)``. ``validated`` holds keys in
            ``schema_fields``, which are merged into the raw input and checked by
            the schema. ``passthrough`` holds keys in ``_NON_LOADABLE_FIELDS``
            (callables etc.), which are added after validation.
        """
        # NOTE(review): any other key is silently dropped (pre-existing behaviour);
        # consider raising on unknown override names.
        validated: dict[str, Any] = {}
        passthrough: dict[str, Any] = {}
        for key, value in overrides.items():
            if key in _NON_LOADABLE_FIELDS:
                passthrough[key] = value
            elif key in schema_fields:
                validated[key] = value
        return validated, passthrough

    def _build_rate_limit_config(
        self,
        raw_config: Mapping[str, Any],
        overrides: Mapping[str, Any],
        source_label: str,
    ) -> RateLimitConfig:
        """Validate raw values plus overrides and build a RateLimitConfig.

        Args:
            raw_config: Values read from the configuration source.
            overrides: Keyword overrides passed by the caller.
            source_label: Source name used in error messages ("environment", "JSON").

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If the input contains non-loadable fields, lacks
                ``limit``, or fails schema validation.
        """
        _check_non_loadable(raw_config)

        validated_overrides, passthrough = self._split_overrides(
            overrides, _RATE_LIMIT_FIELDS
        )
        # Merge loadable overrides before validation so they are coerced like
        # file/env values and can supply the required 'limit' field.
        merged: dict[str, Any] = {**raw_config, **validated_overrides}

        # Checked before validation: otherwise Pydantic's generic "Field required"
        # error fires first and this message could never be produced.
        if "limit" not in merged:
            msg = f"Required field 'limit' not found in {source_label} configuration"
            raise ConfigurationError(msg)

        try:
            schema = _RateLimitSchema.model_validate(merged)
        except ValidationError as e:
            raise ConfigurationError(_format_validation_error(e)) from e

        # exclude_unset keeps every value actually supplied, even one equal to the
        # schema default, so the result never depends on the schema's defaults
        # mirroring RateLimitConfig's own.
        config_dict = schema.model_dump(exclude_unset=True)

        # Apply non-loadable overrides (callables, etc.)
        config_dict.update(passthrough)

        return RateLimitConfig(**config_dict)

    def _build_global_config(
        self,
        raw_config: Mapping[str, Any],
        overrides: Mapping[str, Any],
    ) -> GlobalConfig:
        """Validate raw values plus overrides and build a GlobalConfig.

        Uses the same merge-then-validate flow as the rate-limit loaders, so
        keyword overrides are coerced and checked exactly like file/env values.

        Args:
            raw_config: Values read from the configuration source.
            overrides: Keyword overrides passed by the caller.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If the input contains non-loadable fields or fails
                schema validation.
        """
        _check_non_loadable(raw_config)

        validated_overrides, passthrough = self._split_overrides(
            overrides, _GLOBAL_FIELDS
        )
        merged: dict[str, Any] = {**raw_config, **validated_overrides}

        try:
            schema = _GlobalSchema.model_validate(merged)
        except ValidationError as e:
            raise ConfigurationError(_format_validation_error(e)) from e

        config_dict = schema.model_dump(exclude_unset=True)

        # Apply non-loadable overrides (callables, etc.)
        config_dict.update(passthrough)

        return GlobalConfig(**config_dict)

    # ------------------------------------------------------------------
    # RateLimitConfig loaders
    # ------------------------------------------------------------------

    def load_rate_limit_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from environment variables.

        With the default prefix, environment variables should be prefixed with
        FASTAPI_TRAFFIC_RATE_LIMIT_ (e.g., FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT=100).

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config(
            "RATE_LIMIT_", _RATE_LIMIT_FIELDS, env_source
        )
        return self._build_rate_limit_config(raw_config, overrides, "environment")

    def load_rate_limit_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a .env file.

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        path = Path(file_path)
        env_vars = self._load_dotenv_file(path)
        return self.load_rate_limit_config_from_env(env_vars, **overrides)

    def load_rate_limit_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a JSON file.

        Args:
            file_path: Path to the JSON file. The root must be a JSON object.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        path = Path(file_path)
        raw_config = self._load_json_file(path)
        if not isinstance(raw_config, dict):
            msg = "JSON root must be an object"
            raise ConfigurationError(msg)

        return self._build_rate_limit_config(
            cast("dict[str, Any]", raw_config), overrides, "JSON"
        )

    # ------------------------------------------------------------------
    # GlobalConfig loaders
    # ------------------------------------------------------------------

    def load_global_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from environment variables.

        With the default prefix, environment variables should be prefixed with
        FASTAPI_TRAFFIC_GLOBAL_ (e.g., FASTAPI_TRAFFIC_GLOBAL_ENABLED=true).
        Set-valued fields (exempt_ips, exempt_paths) are comma-separated.

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config("GLOBAL_", _GLOBAL_FIELDS, env_source)

        # Pre-process comma-separated strings into lists for set fields
        processed = self._parse_set_from_string(raw_config)

        return self._build_global_config(processed, overrides)

    def load_global_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a .env file.

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        path = Path(file_path)
        env_vars = self._load_dotenv_file(path)
        return self.load_global_config_from_env(env_vars, **overrides)

    def load_global_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a JSON file.

        Args:
            file_path: Path to the JSON file. The root must be a JSON object.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        path = Path(file_path)
        raw_config = self._load_json_file(path)

        if not isinstance(raw_config, dict):
            msg = "JSON root must be an object"
            raise ConfigurationError(msg)

        return self._build_global_config(cast("dict[str, Any]", raw_config), overrides)
|
|
|
|
|
|
# Convenience functions for direct usage
# Module-wide loader shared by the convenience functions; created on first use.
_default_loader: ConfigLoader | None = None


def _get_default_loader() -> ConfigLoader:
    """Return the shared ``ConfigLoader``, creating it lazily on the first call."""
    global _default_loader
    loader = _default_loader
    if loader is None:
        loader = ConfigLoader()
        _default_loader = loader
    return loader
|
|
|
|
|
|
def load_rate_limit_config(
    source: str | Path,
    **overrides: Any,
) -> RateLimitConfig:
    """Load a RateLimitConfig from a file, picking the parser from its name.

    A ``.json`` suffix (case-insensitive) selects the JSON parser. A ``.env``
    suffix, no suffix at all, or a name starting with ``.env`` (``.env``,
    ``.env.local``) selects the dotenv parser.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()

    if suffix == ".json":
        return loader.load_rate_limit_config_from_json(path, **overrides)

    looks_like_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not looks_like_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)

    return loader.load_rate_limit_config_from_dotenv(path, **overrides)
|
|
|
|
|
|
def load_global_config(
    source: str | Path,
    **overrides: Any,
) -> GlobalConfig:
    """Load a GlobalConfig from a file, picking the parser from its name.

    A ``.json`` suffix (case-insensitive) selects the JSON parser. A ``.env``
    suffix, no suffix at all, or a name starting with ``.env`` (``.env``,
    ``.env.local``) selects the dotenv parser.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()

    if suffix == ".json":
        return loader.load_global_config_from_json(path, **overrides)

    looks_like_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not looks_like_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)

    return loader.load_global_config_from_dotenv(path, **overrides)
|
|
|
|
|
|
def load_rate_limit_config_from_env(
    env_source: Mapping[str, str] | None = None,
    **overrides: Any,
) -> RateLimitConfig:
    """Load RateLimitConfig from environment variables via the shared loader.

    Args:
        env_source: Optional mapping to read variables from instead of
            ``os.environ``. Defaults to ``os.environ``.
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid.
    """
    # env_source is named explicitly so this function mirrors
    # ConfigLoader.load_rate_limit_config_from_env, including positional use.
    return _get_default_loader().load_rate_limit_config_from_env(
        env_source, **overrides
    )
|
|
|
|
|
|
def load_global_config_from_env(
    env_source: Mapping[str, str] | None = None,
    **overrides: Any,
) -> GlobalConfig:
    """Load GlobalConfig from environment variables via the shared loader.

    Args:
        env_source: Optional mapping to read variables from instead of
            ``os.environ``. Defaults to ``os.environ``.
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid.
    """
    # env_source is named explicitly so this function mirrors
    # ConfigLoader.load_global_config_from_env, including positional use.
    return _get_default_loader().load_global_config_from_env(
        env_source, **overrides
    )
|