refactor: improve config loader validation with Pydantic schemas
- Replace manual field type validation with Pydantic model schemas - Add pydantic>=2.0 as core dependency - Fix sync wrapper in decorator to properly handle rate limiting - Update pyright settings for stricter type checking - Fix repository URL in pyproject.toml - Remove unused main.py - Update test assertions for new validation error format
This commit is contained in:
@@ -7,6 +7,8 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
|
||||
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
||||
from fastapi_traffic.exceptions import ConfigurationError
|
||||
@@ -19,35 +21,6 @@ T = TypeVar("T", RateLimitConfig, GlobalConfig)
|
||||
# Environment variable prefix for config values
|
||||
ENV_PREFIX = "FASTAPI_TRAFFIC_"
|
||||
|
||||
# Mapping of config field names to their types for validation
|
||||
_RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = {
|
||||
"limit": int,
|
||||
"window_size": float,
|
||||
"algorithm": Algorithm,
|
||||
"key_prefix": str,
|
||||
"burst_size": int,
|
||||
"include_headers": bool,
|
||||
"error_message": str,
|
||||
"status_code": int,
|
||||
"skip_on_error": bool,
|
||||
"cost": int,
|
||||
}
|
||||
|
||||
_GLOBAL_FIELD_TYPES: dict[str, type[Any]] = {
|
||||
"enabled": bool,
|
||||
"default_limit": int,
|
||||
"default_window_size": float,
|
||||
"default_algorithm": Algorithm,
|
||||
"key_prefix": str,
|
||||
"include_headers": bool,
|
||||
"error_message": str,
|
||||
"status_code": int,
|
||||
"skip_on_error": bool,
|
||||
"exempt_ips": set,
|
||||
"exempt_paths": set,
|
||||
"headers_prefix": str,
|
||||
}
|
||||
|
||||
# Fields that cannot be loaded from config files (callables, complex objects)
|
||||
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
|
||||
{
|
||||
@@ -59,6 +32,98 @@ _NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
class _RateLimitSchema(BaseModel):
|
||||
"""Pydantic schema for validating rate limit configuration input."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
limit: int
|
||||
window_size: float = 60.0
|
||||
algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
|
||||
key_prefix: str = "ratelimit"
|
||||
burst_size: int | None = None
|
||||
include_headers: bool = True
|
||||
error_message: str = "Rate limit exceeded"
|
||||
status_code: int = 429
|
||||
skip_on_error: bool = False
|
||||
cost: int = 1
|
||||
|
||||
@field_validator("algorithm", mode="before")
|
||||
@classmethod
|
||||
def _normalize_algorithm(cls, v: Any) -> Any:
|
||||
if isinstance(v, str):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class _GlobalSchema(BaseModel):
|
||||
"""Pydantic schema for validating global configuration input."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
enabled: bool = True
|
||||
default_limit: int = 100
|
||||
default_window_size: float = 60.0
|
||||
default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
|
||||
key_prefix: str = "fastapi_traffic"
|
||||
include_headers: bool = True
|
||||
error_message: str = "Rate limit exceeded. Please try again later."
|
||||
status_code: int = 429
|
||||
skip_on_error: bool = False
|
||||
exempt_ips: set[str] = set()
|
||||
exempt_paths: set[str] = set()
|
||||
headers_prefix: str = "X-RateLimit"
|
||||
|
||||
@field_validator("default_algorithm", mode="before")
|
||||
@classmethod
|
||||
def _normalize_algorithm(cls, v: Any) -> Any:
|
||||
if isinstance(v, str):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
# Known field names per schema (used for env-var extraction)
|
||||
_RATE_LIMIT_FIELDS: frozenset[str] = frozenset(_RateLimitSchema.model_fields.keys())
|
||||
_GLOBAL_FIELDS: frozenset[str] = frozenset(_GlobalSchema.model_fields.keys())
|
||||
|
||||
|
||||
def _check_non_loadable(data: Mapping[str, Any]) -> None:
|
||||
"""Raise ConfigurationError if data contains non-loadable fields."""
|
||||
for key in data:
|
||||
if key in _NON_LOADABLE_FIELDS:
|
||||
msg = f"Field '{key}' cannot be loaded from configuration files"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
|
||||
def _format_validation_error(exc: ValidationError) -> str:
|
||||
"""Convert a Pydantic ValidationError to a user-friendly message."""
|
||||
errors = exc.errors()
|
||||
if not errors:
|
||||
return str(exc)
|
||||
|
||||
err = errors[0]
|
||||
loc = ".".join(str(p) for p in err["loc"]) if err["loc"] else "unknown"
|
||||
err_type = err["type"]
|
||||
msg = err["msg"]
|
||||
ctx = err.get("ctx", {})
|
||||
|
||||
if err_type == "extra_forbidden":
|
||||
return f"Unknown configuration field: '{loc}'"
|
||||
|
||||
if err_type in ("int_parsing", "float_parsing"):
|
||||
input_val = ctx.get("error", err.get("input", ""))
|
||||
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
||||
|
||||
if err_type == "bool_parsing":
|
||||
return f"Cannot parse value as bool for '{loc}': {msg}"
|
||||
|
||||
if "enum" in err_type or err_type == "value_error":
|
||||
input_val = err.get("input", "")
|
||||
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
||||
|
||||
return f"Invalid value for '{loc}': {msg}"
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""Loader for rate limiting configuration from various sources.
|
||||
|
||||
@@ -83,88 +148,6 @@ class ConfigLoader:
|
||||
"""
|
||||
self._env_prefix = env_prefix
|
||||
|
||||
def _parse_value(self, value: str, target_type: type[Any]) -> Any:
|
||||
"""Parse a string value to the target type.
|
||||
|
||||
Args:
|
||||
value: The string value to parse.
|
||||
target_type: The target type to convert to.
|
||||
|
||||
Returns:
|
||||
The parsed value.
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If the value cannot be parsed.
|
||||
"""
|
||||
try:
|
||||
if target_type is bool:
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
if target_type is int:
|
||||
return int(value)
|
||||
if target_type is float:
|
||||
return float(value)
|
||||
if target_type is str:
|
||||
return value
|
||||
if target_type is Algorithm:
|
||||
return Algorithm(value.lower())
|
||||
if target_type is set:
|
||||
# Parse comma-separated values
|
||||
if not value.strip():
|
||||
return set()
|
||||
return {item.strip() for item in value.split(",") if item.strip()}
|
||||
except (ValueError, KeyError) as e:
|
||||
msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}"
|
||||
raise ConfigurationError(msg) from e
|
||||
|
||||
msg = f"Unsupported type: {target_type}"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
def _validate_and_convert(
|
||||
self,
|
||||
data: Mapping[str, Any],
|
||||
field_types: dict[str, type[Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate and convert configuration data.
|
||||
|
||||
Args:
|
||||
data: Raw configuration data.
|
||||
field_types: Mapping of field names to their expected types.
|
||||
|
||||
Returns:
|
||||
Validated and converted configuration dictionary.
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if key in _NON_LOADABLE_FIELDS:
|
||||
msg = f"Field '{key}' cannot be loaded from configuration files"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
if key not in field_types:
|
||||
msg = f"Unknown configuration field: '{key}'"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
target_type = field_types[key]
|
||||
|
||||
if isinstance(value, str):
|
||||
result[key] = self._parse_value(value, target_type)
|
||||
elif target_type is set and isinstance(value, list):
|
||||
result[key] = set(value)
|
||||
elif target_type is Algorithm and isinstance(value, str):
|
||||
result[key] = Algorithm(value.lower())
|
||||
elif isinstance(value, target_type):
|
||||
result[key] = value
|
||||
elif target_type is float and isinstance(value, int):
|
||||
result[key] = float(value)
|
||||
else:
|
||||
msg = f"Invalid type for '{key}': expected {target_type.__name__}, got {type(value).__name__}"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
return result
|
||||
|
||||
def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
|
||||
"""Load environment variables from a .env file.
|
||||
|
||||
@@ -248,14 +231,14 @@ class ConfigLoader:
|
||||
def _extract_env_config(
|
||||
self,
|
||||
prefix: str,
|
||||
field_types: dict[str, type[Any]],
|
||||
known_fields: frozenset[str],
|
||||
env_source: Mapping[str, str] | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Extract configuration from environment variables.
|
||||
|
||||
Args:
|
||||
prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
|
||||
field_types: Mapping of field names to their expected types.
|
||||
known_fields: Set of known field names.
|
||||
env_source: Optional source of environment variables. Defaults to os.environ.
|
||||
|
||||
Returns:
|
||||
@@ -268,11 +251,29 @@ class ConfigLoader:
|
||||
for key, value in source.items():
|
||||
if key.startswith(full_prefix):
|
||||
field_name = key[len(full_prefix) :].lower()
|
||||
if field_name in field_types:
|
||||
if field_name in known_fields:
|
||||
result[field_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _parse_set_from_string(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Pre-process comma-separated string values into lists for set fields.
|
||||
|
||||
This handles the env-var case where sets are represented as
|
||||
comma-separated strings (e.g., "127.0.0.1, 10.0.0.1").
|
||||
"""
|
||||
result = dict(data)
|
||||
for key in ("exempt_ips", "exempt_paths"):
|
||||
if key in result and isinstance(result[key], str):
|
||||
value = result[key].strip()
|
||||
if not value:
|
||||
result[key] = []
|
||||
else:
|
||||
result[key] = [
|
||||
item.strip() for item in value.split(",") if item.strip()
|
||||
]
|
||||
return result
|
||||
|
||||
def load_rate_limit_config_from_env(
|
||||
self,
|
||||
env_source: Mapping[str, str] | None = None,
|
||||
@@ -294,13 +295,21 @@ class ConfigLoader:
|
||||
ConfigurationError: If configuration is invalid.
|
||||
"""
|
||||
raw_config = self._extract_env_config(
|
||||
"RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source
|
||||
"RATE_LIMIT_", _RATE_LIMIT_FIELDS, env_source
|
||||
)
|
||||
config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
try:
|
||||
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||
except ValidationError as e:
|
||||
raise ConfigurationError(_format_validation_error(e)) from e
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES:
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure required field 'limit' is present
|
||||
@@ -353,11 +362,19 @@ class ConfigLoader:
|
||||
if not isinstance(raw_config, dict):
|
||||
msg = "JSON root must be an object"
|
||||
raise ConfigurationError(msg)
|
||||
config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
try:
|
||||
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||
except ValidationError as e:
|
||||
raise ConfigurationError(_format_validation_error(e)) from e
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES:
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure required field 'limit' is present
|
||||
@@ -388,13 +405,24 @@ class ConfigLoader:
|
||||
ConfigurationError: If configuration is invalid.
|
||||
"""
|
||||
raw_config = self._extract_env_config(
|
||||
"GLOBAL_", _GLOBAL_FIELD_TYPES, env_source
|
||||
"GLOBAL_", _GLOBAL_FIELDS, env_source
|
||||
)
|
||||
config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
# Pre-process comma-separated strings into lists for set fields
|
||||
processed = self._parse_set_from_string(raw_config)
|
||||
|
||||
try:
|
||||
schema = _GlobalSchema(**processed) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||
except ValidationError as e:
|
||||
raise ConfigurationError(_format_validation_error(e)) from e
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES:
|
||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS:
|
||||
config_dict[key] = value
|
||||
|
||||
return GlobalConfig(**config_dict)
|
||||
@@ -439,11 +467,23 @@ class ConfigLoader:
|
||||
"""
|
||||
path = Path(file_path)
|
||||
raw_config = self._load_json_file(path)
|
||||
config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
|
||||
|
||||
if not isinstance(raw_config, dict):
|
||||
msg = "JSON root must be an object"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
try:
|
||||
schema = _GlobalSchema(**raw_config)
|
||||
except ValidationError as e:
|
||||
raise ConfigurationError(_format_validation_error(e)) from e
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES:
|
||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS:
|
||||
config_dict[key] = value
|
||||
|
||||
return GlobalConfig(**config_dict)
|
||||
|
||||
Reference in New Issue
Block a user