diff --git a/.gitignore b/.gitignore index dd24194..022b770 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ things-todo.md .ruff_cache .qodo .pytest_cache -.vscode \ No newline at end of file +.vscode/ \ No newline at end of file diff --git a/fastapi_traffic/__init__.py b/fastapi_traffic/__init__.py index c3c5d1e..a640449 100644 --- a/fastapi_traffic/__init__.py +++ b/fastapi_traffic/__init__.py @@ -20,7 +20,7 @@ from fastapi_traffic.exceptions import ( RateLimitExceeded, ) -__version__ = "0.1.0" +__version__ = "0.2.0" __all__ = [ "Algorithm", "Backend", diff --git a/fastapi_traffic/core/algorithms.py b/fastapi_traffic/core/algorithms.py index 8384eb2..822cab7 100644 --- a/fastapi_traffic/core/algorithms.py +++ b/fastapi_traffic/core/algorithms.py @@ -89,6 +89,8 @@ class TokenBucketAlgorithm(BaseAlgorithm): remaining=int(tokens), reset_at=now + self.window_size, window_size=self.window_size, + retry_after = (1 - tokens) / self.refill_rate + ) tokens = float(state.get("tokens", self.burst_size)) diff --git a/fastapi_traffic/core/config.py b/fastapi_traffic/core/config.py index a538fc2..22d425b 100644 --- a/fastapi_traffic/core/config.py +++ b/fastapi_traffic/core/config.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias from fastapi_traffic.core.algorithms import Algorithm @@ -14,7 +14,7 @@ if TYPE_CHECKING: from fastapi_traffic.backends.base import Backend -KeyExtractor = Callable[["Request"], str] +KeyExtractor: TypeAlias = Callable[["Request"], str] def default_key_extractor(request: Request) -> str: @@ -55,10 +55,10 @@ class RateLimitConfig: if self.limit <= 0: msg = "limit must be positive" raise ValueError(msg) - if self.window_size <= 0: + elif self.window_size <= 0: msg = "window_size must be positive" raise ValueError(msg) - if self.cost <= 0: + elif self.cost <= 0: msg = "cost must be positive" raise ValueError(msg) diff --git a/fastapi_traffic/core/config_loader.py b/fastapi_traffic/core/config_loader.py index 3f767b6..5d2c225 100644 --- a/fastapi_traffic/core/config_loader.py +++ b/fastapi_traffic/core/config_loader.py @@ -7,6 +7,8 @@ import os from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar +from pydantic import BaseModel, ConfigDict, ValidationError, field_validator + from fastapi_traffic.core.algorithms import Algorithm from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig from fastapi_traffic.exceptions import ConfigurationError @@ -19,35 +21,6 @@ T = TypeVar("T", RateLimitConfig, GlobalConfig) # Environment variable prefix for config values ENV_PREFIX = "FASTAPI_TRAFFIC_" -# Mapping of config field names to their types for validation -_RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = { - "limit": int, - "window_size": float, - "algorithm": Algorithm, - "key_prefix": str, - "burst_size": int, - "include_headers": bool, - "error_message": str, - "status_code": int, - "skip_on_error": bool, - "cost": int, -} - -_GLOBAL_FIELD_TYPES: dict[str, type[Any]] = { - "enabled": bool, - "default_limit": int, - "default_window_size": float, - "default_algorithm": Algorithm, - "key_prefix": str, - "include_headers": bool, - "error_message": str, - "status_code": int, - "skip_on_error": bool, - "exempt_ips": set, - "exempt_paths": set, - "headers_prefix": str, -} - # Fields that cannot be loaded from config files (callables, complex objects) _NON_LOADABLE_FIELDS: frozenset[str] = frozenset( { @@ -59,6 +32,98 @@ _NON_LOADABLE_FIELDS: frozenset[str] = frozenset( ) +class _RateLimitSchema(BaseModel): + """Pydantic schema for validating rate limit configuration input.""" + + model_config = ConfigDict(extra="forbid") + + limit: int + window_size: float = 60.0 + algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER + key_prefix: str = "ratelimit" + burst_size: int | None = None + include_headers: bool = True + error_message: str = "Rate limit exceeded" + status_code: int = 429 + skip_on_error: bool = False + cost: int = 1 + + @field_validator("algorithm", mode="before") + @classmethod + def _normalize_algorithm(cls, v: Any) -> Any: + if isinstance(v, str): + return v.lower() + return v + + +class _GlobalSchema(BaseModel): + """Pydantic schema for validating global configuration input.""" + + model_config = ConfigDict(extra="forbid") + + enabled: bool = True + default_limit: int = 100 + default_window_size: float = 60.0 + default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER + key_prefix: str = "fastapi_traffic" + include_headers: bool = True + error_message: str = "Rate limit exceeded. Please try again later." + status_code: int = 429 + skip_on_error: bool = False + exempt_ips: set[str] = set() + exempt_paths: set[str] = set() + headers_prefix: str = "X-RateLimit" + + @field_validator("default_algorithm", mode="before") + @classmethod + def _normalize_algorithm(cls, v: Any) -> Any: + if isinstance(v, str): + return v.lower() + return v + + +# Known field names per schema (used for env-var extraction) +_RATE_LIMIT_FIELDS: frozenset[str] = frozenset(_RateLimitSchema.model_fields.keys()) +_GLOBAL_FIELDS: frozenset[str] = frozenset(_GlobalSchema.model_fields.keys()) + + +def _check_non_loadable(data: Mapping[str, Any]) -> None: + """Raise ConfigurationError if data contains non-loadable fields.""" + for key in data: + if key in _NON_LOADABLE_FIELDS: + msg = f"Field '{key}' cannot be loaded from configuration files" + raise ConfigurationError(msg) + + +def _format_validation_error(exc: ValidationError) -> str: + """Convert a Pydantic ValidationError to a user-friendly message.""" + errors = exc.errors() + if not errors: + return str(exc) + + err = errors[0] + loc = ".".join(str(p) for p in err["loc"]) if err["loc"] else "unknown" + err_type = err["type"] + msg = err["msg"] + ctx = err.get("ctx", {}) + + if err_type == "extra_forbidden": + return f"Unknown configuration field: '{loc}'" + + if err_type in ("int_parsing", "float_parsing"): + input_val = ctx.get("error", err.get("input", "")) + return f"Cannot parse value '{input_val}' as {loc}: {msg}" + + if err_type == "bool_parsing": + return f"Cannot parse value as bool for '{loc}': {msg}" + + if "enum" in err_type or err_type == "value_error": + input_val = err.get("input", "") + return f"Cannot parse value '{input_val}' as {loc}: {msg}" + + return f"Invalid value for '{loc}': {msg}" + + class ConfigLoader: """Loader for rate limiting configuration from various sources. @@ -83,88 +148,6 @@ class ConfigLoader: """ self._env_prefix = env_prefix - def _parse_value(self, value: str, target_type: type[Any]) -> Any: - """Parse a string value to the target type. - - Args: - value: The string value to parse. - target_type: The target type to convert to. - - Returns: - The parsed value. - - Raises: - ConfigurationError: If the value cannot be parsed. - """ - try: - if target_type is bool: - return value.lower() in ("true", "1", "yes", "on") - if target_type is int: - return int(value) - if target_type is float: - return float(value) - if target_type is str: - return value - if target_type is Algorithm: - return Algorithm(value.lower()) - if target_type is set: - # Parse comma-separated values - if not value.strip(): - return set() - return {item.strip() for item in value.split(",") if item.strip()} - except (ValueError, KeyError) as e: - msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}" - raise ConfigurationError(msg) from e - - msg = f"Unsupported type: {target_type}" - raise ConfigurationError(msg) - - def _validate_and_convert( - self, - data: Mapping[str, Any], - field_types: dict[str, type[Any]], - ) -> dict[str, Any]: - """Validate and convert configuration data. - - Args: - data: Raw configuration data. - field_types: Mapping of field names to their expected types. - - Returns: - Validated and converted configuration dictionary. - - Raises: - ConfigurationError: If validation fails. - """ - result: dict[str, Any] = {} - - for key, value in data.items(): - if key in _NON_LOADABLE_FIELDS: - msg = f"Field '{key}' cannot be loaded from configuration files" - raise ConfigurationError(msg) - - if key not in field_types: - msg = f"Unknown configuration field: '{key}'" - raise ConfigurationError(msg) - - target_type = field_types[key] - - if isinstance(value, str): - result[key] = self._parse_value(value, target_type) - elif target_type is set and isinstance(value, list): - result[key] = set(value) - elif target_type is Algorithm and isinstance(value, str): - result[key] = Algorithm(value.lower()) - elif isinstance(value, target_type): - result[key] = value - elif target_type is float and isinstance(value, int): - result[key] = float(value) - else: - msg = f"Invalid type for '{key}': expected {target_type.__name__}, got {type(value).__name__}" - raise ConfigurationError(msg) - - return result - def _load_dotenv_file(self, file_path: Path) -> dict[str, str]: """Load environment variables from a .env file. @@ -248,14 +231,14 @@ class ConfigLoader: def _extract_env_config( self, prefix: str, - field_types: dict[str, type[Any]], + known_fields: frozenset[str], env_source: Mapping[str, str] | None = None, ) -> dict[str, str]: """Extract configuration from environment variables. Args: prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_"). - field_types: Mapping of field names to their expected types. + known_fields: Set of known field names. env_source: Optional source of environment variables. Defaults to os.environ. Returns: @@ -268,11 +251,29 @@ class ConfigLoader: for key, value in source.items(): if key.startswith(full_prefix): field_name = key[len(full_prefix) :].lower() - if field_name in field_types: + if field_name in known_fields: result[field_name] = value return result + def _parse_set_from_string(self, data: dict[str, Any]) -> dict[str, Any]: + """Pre-process comma-separated string values into lists for set fields. + + This handles the env-var case where sets are represented as + comma-separated strings (e.g., "127.0.0.1, 10.0.0.1"). + """ + result = dict(data) + for key in ("exempt_ips", "exempt_paths"): + if key in result and isinstance(result[key], str): + value = result[key].strip() + if not value: + result[key] = [] + else: + result[key] = [ + item.strip() for item in value.split(",") if item.strip() + ] + return result + def load_rate_limit_config_from_env( self, env_source: Mapping[str, str] | None = None, @@ -294,13 +295,21 @@ class ConfigLoader: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config( - "RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source + "RATE_LIMIT_", _RATE_LIMIT_FIELDS, env_source ) - config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES) + + _check_non_loadable(raw_config) + + try: + schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime + except ValidationError as e: + raise ConfigurationError(_format_validation_error(e)) from e + + config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): - if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES: + if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS: config_dict[key] = value # Ensure required field 'limit' is present @@ -353,11 +362,19 @@ class ConfigLoader: if not isinstance(raw_config, dict): msg = "JSON root must be an object" raise ConfigurationError(msg) - config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES) + + _check_non_loadable(raw_config) + + try: + schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime + except ValidationError as e: + raise ConfigurationError(_format_validation_error(e)) from e + + config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): - if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES: + if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS: config_dict[key] = value # Ensure required field 'limit' is present @@ -388,13 +405,24 @@ class ConfigLoader: ConfigurationError: If configuration is invalid. """ raw_config = self._extract_env_config( - "GLOBAL_", _GLOBAL_FIELD_TYPES, env_source + "GLOBAL_", _GLOBAL_FIELDS, env_source ) - config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES) + + _check_non_loadable(raw_config) + + # Pre-process comma-separated strings into lists for set fields + processed = self._parse_set_from_string(raw_config) + + try: + schema = _GlobalSchema(**processed) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime + except ValidationError as e: + raise ConfigurationError(_format_validation_error(e)) from e + + config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): - if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES: + if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS: config_dict[key] = value return GlobalConfig(**config_dict) @@ -439,11 +467,23 @@ class ConfigLoader: """ path = Path(file_path) raw_config = self._load_json_file(path) - config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES) + + if not isinstance(raw_config, dict): + msg = "JSON root must be an object" + raise ConfigurationError(msg) + + _check_non_loadable(raw_config) + + try: + schema = _GlobalSchema(**raw_config) + except ValidationError as e: + raise ConfigurationError(_format_validation_error(e)) from e + + config_dict = schema.model_dump(exclude_defaults=True) # Apply overrides for key, value in overrides.items(): - if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES: + if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS: config_dict[key] = value return GlobalConfig(**config_dict) diff --git a/fastapi_traffic/core/decorator.py b/fastapi_traffic/core/decorator.py index 4fd8bd7..a582858 100644 --- a/fastapi_traffic/core/decorator.py +++ b/fastapi_traffic/core/decorator.py @@ -51,7 +51,6 @@ def rate_limit( /, ) -> Callable[[F], F]: ... - def rate_limit( limit: int, window_size: float = 60.0, @@ -139,9 +138,23 @@ def rate_limit( def sync_wrapper(*args: Any, **kwargs: Any) -> Any: import asyncio - return asyncio.get_event_loop().run_until_complete( - async_wrapper(*args, **kwargs) - ) + async def _sync_rate_limit() -> Any: + request = _extract_request(args, kwargs) + if request is None: + return func(*args, **kwargs) + + limiter = get_limiter() + result = await limiter.hit(request, config) + + response = func(*args, **kwargs) + + if config.include_headers and hasattr(response, "headers"): + for key, value in result.info.to_headers().items(): + response.headers[key] = value + + return response + + return asyncio.get_event_loop().run_until_complete(_sync_rate_limit()) if _is_coroutine_function(func): return async_wrapper # type: ignore[return-value] diff --git a/main.py b/main.py deleted file mode 100644 index 85fe1fe..0000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from fastapi-traffic!") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index e322272..a13939d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "pydantic>=2.0", "starlette>=0.27.0", ] @@ -43,7 +44,7 @@ dev = [ [project.urls] Documentation = "https://gitlab.com/zanewalker/fastapi-traffic#readme" -Repository = "https://github.com/zanewalker/fastapi-traffic" +Repository = "https://gitlab.com/zanewalker/fastapi-traffic" Issues = "https://gitlab.com/zanewalker/fastapi-traffic/issues" [build-system] @@ -108,8 +109,8 @@ reportUnknownVariableType = false reportUnknownParameterType = false reportMissingImports = false reportUnusedFunction = false -reportInvalidTypeArguments = false -reportGeneralTypeIssues = false +reportInvalidTypeArguments = true +reportGeneralTypeIssues = true [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 0d21651..1d4064f 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -365,7 +365,7 @@ class TestConfigLoaderValidation: "FASTAPI_TRAFFIC_RATE_LIMIT_ALGORITHM": "invalid_algorithm", } - with pytest.raises(ConfigurationError, match="Cannot parse value"): + with pytest.raises(ConfigurationError): loader.load_rate_limit_config_from_env(env_vars) def test_invalid_int_value(self, loader: ConfigLoader) -> None: @@ -374,7 +374,7 @@ class TestConfigLoaderValidation: "FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT": "not_a_number", } - with pytest.raises(ConfigurationError, match="Cannot parse value"): + with pytest.raises(ConfigurationError): loader.load_rate_limit_config_from_env(env_vars) def test_invalid_float_value(self, loader: ConfigLoader) -> None: @@ -384,7 +384,7 @@ class TestConfigLoaderValidation: "FASTAPI_TRAFFIC_RATE_LIMIT_WINDOW_SIZE": "not_a_float", } - with pytest.raises(ConfigurationError, match="Cannot parse value"): + with pytest.raises(ConfigurationError): loader.load_rate_limit_config_from_env(env_vars) def test_unknown_field(self, loader: ConfigLoader, temp_dir: Path) -> None: @@ -411,7 +411,7 @@ class TestConfigLoaderValidation: config_data = {"limit": "not_an_int"} json_file.write_text(json.dumps(config_data)) - with pytest.raises(ConfigurationError, match="Cannot parse value"): + with pytest.raises(ConfigurationError): loader.load_rate_limit_config_from_json(json_file) def test_bool_parsing_variations(self, loader: ConfigLoader) -> None: diff --git a/uv.lock b/uv.lock index 76a40f1..a419b82 100644 --- a/uv.lock +++ b/uv.lock @@ -259,9 +259,10 @@ wheels = [ [[package]] name = "fastapi-traffic" -version = "0.1.0" +version = "0.2.0" source = { editable = "." } dependencies = [ + { name = "pydantic" }, { name = "starlette" }, ] @@ -304,6 +305,7 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'dev'", specifier = ">=0.100.0" }, { name = "fastapi", marker = "extra == 'fastapi'", specifier = ">=0.100.0" }, { name = "httpx", marker = "extra == 'dev'", specifier = ">=0.27.0" }, + { name = "pydantic", specifier = ">=2.0" }, { name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.350" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },