Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 492410614f | |||
| 4f19c0b19e | |||
| fe07912040 | |||
| 6bdeab2b4e |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -13,4 +13,4 @@ things-todo.md
|
|||||||
.ruff_cache
|
.ruff_cache
|
||||||
.qodo
|
.qodo
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
.vscode
|
.vscode/
|
||||||
@@ -95,18 +95,6 @@ build-package:
|
|||||||
- if: $CI_COMMIT_TAG
|
- if: $CI_COMMIT_TAG
|
||||||
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
|
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
|
||||||
|
|
||||||
# Publish to PyPI (only on tags)
|
|
||||||
publish-pypi:
|
|
||||||
extends: .python-base
|
|
||||||
stage: publish
|
|
||||||
script:
|
|
||||||
- uv publish --token $PYPI_TOKEN
|
|
||||||
rules:
|
|
||||||
- if: $CI_COMMIT_TAG =~ /^v\d+\.\d+\.\d+$/
|
|
||||||
when: manual
|
|
||||||
needs:
|
|
||||||
- build-package
|
|
||||||
|
|
||||||
# Publish to GitLab Package Registry
|
# Publish to GitLab Package Registry
|
||||||
publish-gitlab:
|
publish-gitlab:
|
||||||
extends: .python-base
|
extends: .python-base
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Want to contribute or just poke around? Here's how to get set up.
|
|||||||
### Using uv (the fast way)
|
### Using uv (the fast way)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://gitlab.com/bereckobrian/fastapi-traffic.git
|
git clone https://gitlab.com/zanewalker/fastapi-traffic.git
|
||||||
cd fastapi-traffic
|
cd fastapi-traffic
|
||||||
|
|
||||||
# This creates a venv and installs everything
|
# This creates a venv and installs everything
|
||||||
@@ -25,7 +25,7 @@ That's it. uv figures out the rest.
|
|||||||
### Using pip
|
### Using pip
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://gitlab.com/bereckobrian/fastapi-traffic.git
|
git clone https://gitlab.com/zanewalker/fastapi-traffic.git
|
||||||
cd fastapi-traffic
|
cd fastapi-traffic
|
||||||
|
|
||||||
python -m venv .venv
|
python -m venv .venv
|
||||||
|
|||||||
12
README.md
12
README.md
@@ -18,26 +18,26 @@ Most rate limiting solutions are either too simple (fixed window only) or too co
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Basic installation (memory backend only)
|
# Basic installation (memory backend only)
|
||||||
pip install fastapi-traffic
|
pip install git+https://gitlab.com/zanewalker/fastapi-traffic.git
|
||||||
|
|
||||||
# With Redis support
|
# With Redis support
|
||||||
pip install fastapi-traffic[redis]
|
pip install git+https://gitlab.com/zanewalker/fastapi-traffic.git[redis]
|
||||||
|
|
||||||
# With all extras
|
# With all extras
|
||||||
pip install fastapi-traffic[all]
|
pip install git+https://gitlab.com/zanewalker/fastapi-traffic.git[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using uv
|
### Using uv
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Basic installation
|
# Basic installation
|
||||||
uv add fastapi-traffic
|
uv add git+https://gitlab.com/zanewalker/fastapi-traffic.git
|
||||||
|
|
||||||
# With Redis support
|
# With Redis support
|
||||||
uv add fastapi-traffic[redis]
|
uv add git+https://gitlab.com/zanewalker/fastapi-traffic.git[redis]
|
||||||
|
|
||||||
# With all extras
|
# With all extras
|
||||||
uv add fastapi-traffic[all]
|
uv add git+https://gitlab.com/zanewalker/fastapi-traffic.git[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from fastapi_traffic.exceptions import (
|
|||||||
RateLimitExceeded,
|
RateLimitExceeded,
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.2.0"
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Algorithm",
|
"Algorithm",
|
||||||
"Backend",
|
"Backend",
|
||||||
|
|||||||
@@ -89,6 +89,8 @@ class TokenBucketAlgorithm(BaseAlgorithm):
|
|||||||
remaining=int(tokens),
|
remaining=int(tokens),
|
||||||
reset_at=now + self.window_size,
|
reset_at=now + self.window_size,
|
||||||
window_size=self.window_size,
|
window_size=self.window_size,
|
||||||
|
retry_after = (1 - tokens) / self.refill_rate
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tokens = float(state.get("tokens", self.burst_size))
|
tokens = float(state.get("tokens", self.burst_size))
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||||
|
|
||||||
from fastapi_traffic.core.algorithms import Algorithm
|
from fastapi_traffic.core.algorithms import Algorithm
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
from fastapi_traffic.backends.base import Backend
|
from fastapi_traffic.backends.base import Backend
|
||||||
|
|
||||||
|
|
||||||
KeyExtractor = Callable[["Request"], str]
|
KeyExtractor: TypeAlias = Callable[["Request"], str]
|
||||||
|
|
||||||
|
|
||||||
def default_key_extractor(request: Request) -> str:
|
def default_key_extractor(request: Request) -> str:
|
||||||
@@ -55,10 +55,10 @@ class RateLimitConfig:
|
|||||||
if self.limit <= 0:
|
if self.limit <= 0:
|
||||||
msg = "limit must be positive"
|
msg = "limit must be positive"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
if self.window_size <= 0:
|
elif self.window_size <= 0:
|
||||||
msg = "window_size must be positive"
|
msg = "window_size must be positive"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
if self.cost <= 0:
|
elif self.cost <= 0:
|
||||||
msg = "cost must be positive"
|
msg = "cost must be positive"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
|
||||||
|
|
||||||
from fastapi_traffic.core.algorithms import Algorithm
|
from fastapi_traffic.core.algorithms import Algorithm
|
||||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
||||||
from fastapi_traffic.exceptions import ConfigurationError
|
from fastapi_traffic.exceptions import ConfigurationError
|
||||||
@@ -19,35 +21,6 @@ T = TypeVar("T", RateLimitConfig, GlobalConfig)
|
|||||||
# Environment variable prefix for config values
|
# Environment variable prefix for config values
|
||||||
ENV_PREFIX = "FASTAPI_TRAFFIC_"
|
ENV_PREFIX = "FASTAPI_TRAFFIC_"
|
||||||
|
|
||||||
# Mapping of config field names to their types for validation
|
|
||||||
_RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = {
|
|
||||||
"limit": int,
|
|
||||||
"window_size": float,
|
|
||||||
"algorithm": Algorithm,
|
|
||||||
"key_prefix": str,
|
|
||||||
"burst_size": int,
|
|
||||||
"include_headers": bool,
|
|
||||||
"error_message": str,
|
|
||||||
"status_code": int,
|
|
||||||
"skip_on_error": bool,
|
|
||||||
"cost": int,
|
|
||||||
}
|
|
||||||
|
|
||||||
_GLOBAL_FIELD_TYPES: dict[str, type[Any]] = {
|
|
||||||
"enabled": bool,
|
|
||||||
"default_limit": int,
|
|
||||||
"default_window_size": float,
|
|
||||||
"default_algorithm": Algorithm,
|
|
||||||
"key_prefix": str,
|
|
||||||
"include_headers": bool,
|
|
||||||
"error_message": str,
|
|
||||||
"status_code": int,
|
|
||||||
"skip_on_error": bool,
|
|
||||||
"exempt_ips": set,
|
|
||||||
"exempt_paths": set,
|
|
||||||
"headers_prefix": str,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Fields that cannot be loaded from config files (callables, complex objects)
|
# Fields that cannot be loaded from config files (callables, complex objects)
|
||||||
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
|
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
@@ -59,6 +32,98 @@ _NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RateLimitSchema(BaseModel):
|
||||||
|
"""Pydantic schema for validating rate limit configuration input."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
limit: int
|
||||||
|
window_size: float = 60.0
|
||||||
|
algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
|
||||||
|
key_prefix: str = "ratelimit"
|
||||||
|
burst_size: int | None = None
|
||||||
|
include_headers: bool = True
|
||||||
|
error_message: str = "Rate limit exceeded"
|
||||||
|
status_code: int = 429
|
||||||
|
skip_on_error: bool = False
|
||||||
|
cost: int = 1
|
||||||
|
|
||||||
|
@field_validator("algorithm", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_algorithm(cls, v: Any) -> Any:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.lower()
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class _GlobalSchema(BaseModel):
|
||||||
|
"""Pydantic schema for validating global configuration input."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
default_limit: int = 100
|
||||||
|
default_window_size: float = 60.0
|
||||||
|
default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
|
||||||
|
key_prefix: str = "fastapi_traffic"
|
||||||
|
include_headers: bool = True
|
||||||
|
error_message: str = "Rate limit exceeded. Please try again later."
|
||||||
|
status_code: int = 429
|
||||||
|
skip_on_error: bool = False
|
||||||
|
exempt_ips: set[str] = set()
|
||||||
|
exempt_paths: set[str] = set()
|
||||||
|
headers_prefix: str = "X-RateLimit"
|
||||||
|
|
||||||
|
@field_validator("default_algorithm", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_algorithm(cls, v: Any) -> Any:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.lower()
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
# Known field names per schema (used for env-var extraction)
|
||||||
|
_RATE_LIMIT_FIELDS: frozenset[str] = frozenset(_RateLimitSchema.model_fields.keys())
|
||||||
|
_GLOBAL_FIELDS: frozenset[str] = frozenset(_GlobalSchema.model_fields.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def _check_non_loadable(data: Mapping[str, Any]) -> None:
|
||||||
|
"""Raise ConfigurationError if data contains non-loadable fields."""
|
||||||
|
for key in data:
|
||||||
|
if key in _NON_LOADABLE_FIELDS:
|
||||||
|
msg = f"Field '{key}' cannot be loaded from configuration files"
|
||||||
|
raise ConfigurationError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_validation_error(exc: ValidationError) -> str:
|
||||||
|
"""Convert a Pydantic ValidationError to a user-friendly message."""
|
||||||
|
errors = exc.errors()
|
||||||
|
if not errors:
|
||||||
|
return str(exc)
|
||||||
|
|
||||||
|
err = errors[0]
|
||||||
|
loc = ".".join(str(p) for p in err["loc"]) if err["loc"] else "unknown"
|
||||||
|
err_type = err["type"]
|
||||||
|
msg = err["msg"]
|
||||||
|
ctx = err.get("ctx", {})
|
||||||
|
|
||||||
|
if err_type == "extra_forbidden":
|
||||||
|
return f"Unknown configuration field: '{loc}'"
|
||||||
|
|
||||||
|
if err_type in ("int_parsing", "float_parsing"):
|
||||||
|
input_val = ctx.get("error", err.get("input", ""))
|
||||||
|
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
||||||
|
|
||||||
|
if err_type == "bool_parsing":
|
||||||
|
return f"Cannot parse value as bool for '{loc}': {msg}"
|
||||||
|
|
||||||
|
if "enum" in err_type or err_type == "value_error":
|
||||||
|
input_val = err.get("input", "")
|
||||||
|
return f"Cannot parse value '{input_val}' as {loc}: {msg}"
|
||||||
|
|
||||||
|
return f"Invalid value for '{loc}': {msg}"
|
||||||
|
|
||||||
|
|
||||||
class ConfigLoader:
|
class ConfigLoader:
|
||||||
"""Loader for rate limiting configuration from various sources.
|
"""Loader for rate limiting configuration from various sources.
|
||||||
|
|
||||||
@@ -83,88 +148,6 @@ class ConfigLoader:
|
|||||||
"""
|
"""
|
||||||
self._env_prefix = env_prefix
|
self._env_prefix = env_prefix
|
||||||
|
|
||||||
def _parse_value(self, value: str, target_type: type[Any]) -> Any:
|
|
||||||
"""Parse a string value to the target type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: The string value to parse.
|
|
||||||
target_type: The target type to convert to.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The parsed value.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigurationError: If the value cannot be parsed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if target_type is bool:
|
|
||||||
return value.lower() in ("true", "1", "yes", "on")
|
|
||||||
if target_type is int:
|
|
||||||
return int(value)
|
|
||||||
if target_type is float:
|
|
||||||
return float(value)
|
|
||||||
if target_type is str:
|
|
||||||
return value
|
|
||||||
if target_type is Algorithm:
|
|
||||||
return Algorithm(value.lower())
|
|
||||||
if target_type is set:
|
|
||||||
# Parse comma-separated values
|
|
||||||
if not value.strip():
|
|
||||||
return set()
|
|
||||||
return {item.strip() for item in value.split(",") if item.strip()}
|
|
||||||
except (ValueError, KeyError) as e:
|
|
||||||
msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}"
|
|
||||||
raise ConfigurationError(msg) from e
|
|
||||||
|
|
||||||
msg = f"Unsupported type: {target_type}"
|
|
||||||
raise ConfigurationError(msg)
|
|
||||||
|
|
||||||
def _validate_and_convert(
|
|
||||||
self,
|
|
||||||
data: Mapping[str, Any],
|
|
||||||
field_types: dict[str, type[Any]],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate and convert configuration data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Raw configuration data.
|
|
||||||
field_types: Mapping of field names to their expected types.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated and converted configuration dictionary.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigurationError: If validation fails.
|
|
||||||
"""
|
|
||||||
result: dict[str, Any] = {}
|
|
||||||
|
|
||||||
for key, value in data.items():
|
|
||||||
if key in _NON_LOADABLE_FIELDS:
|
|
||||||
msg = f"Field '{key}' cannot be loaded from configuration files"
|
|
||||||
raise ConfigurationError(msg)
|
|
||||||
|
|
||||||
if key not in field_types:
|
|
||||||
msg = f"Unknown configuration field: '{key}'"
|
|
||||||
raise ConfigurationError(msg)
|
|
||||||
|
|
||||||
target_type = field_types[key]
|
|
||||||
|
|
||||||
if isinstance(value, str):
|
|
||||||
result[key] = self._parse_value(value, target_type)
|
|
||||||
elif target_type is set and isinstance(value, list):
|
|
||||||
result[key] = set(value)
|
|
||||||
elif target_type is Algorithm and isinstance(value, str):
|
|
||||||
result[key] = Algorithm(value.lower())
|
|
||||||
elif isinstance(value, target_type):
|
|
||||||
result[key] = value
|
|
||||||
elif target_type is float and isinstance(value, int):
|
|
||||||
result[key] = float(value)
|
|
||||||
else:
|
|
||||||
msg = f"Invalid type for '{key}': expected {target_type.__name__}, got {type(value).__name__}"
|
|
||||||
raise ConfigurationError(msg)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
|
def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
|
||||||
"""Load environment variables from a .env file.
|
"""Load environment variables from a .env file.
|
||||||
|
|
||||||
@@ -248,14 +231,14 @@ class ConfigLoader:
|
|||||||
def _extract_env_config(
|
def _extract_env_config(
|
||||||
self,
|
self,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
field_types: dict[str, type[Any]],
|
known_fields: frozenset[str],
|
||||||
env_source: Mapping[str, str] | None = None,
|
env_source: Mapping[str, str] | None = None,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Extract configuration from environment variables.
|
"""Extract configuration from environment variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
|
prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
|
||||||
field_types: Mapping of field names to their expected types.
|
known_fields: Set of known field names.
|
||||||
env_source: Optional source of environment variables. Defaults to os.environ.
|
env_source: Optional source of environment variables. Defaults to os.environ.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -268,11 +251,29 @@ class ConfigLoader:
|
|||||||
for key, value in source.items():
|
for key, value in source.items():
|
||||||
if key.startswith(full_prefix):
|
if key.startswith(full_prefix):
|
||||||
field_name = key[len(full_prefix) :].lower()
|
field_name = key[len(full_prefix) :].lower()
|
||||||
if field_name in field_types:
|
if field_name in known_fields:
|
||||||
result[field_name] = value
|
result[field_name] = value
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _parse_set_from_string(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Pre-process comma-separated string values into lists for set fields.
|
||||||
|
|
||||||
|
This handles the env-var case where sets are represented as
|
||||||
|
comma-separated strings (e.g., "127.0.0.1, 10.0.0.1").
|
||||||
|
"""
|
||||||
|
result = dict(data)
|
||||||
|
for key in ("exempt_ips", "exempt_paths"):
|
||||||
|
if key in result and isinstance(result[key], str):
|
||||||
|
value = result[key].strip()
|
||||||
|
if not value:
|
||||||
|
result[key] = []
|
||||||
|
else:
|
||||||
|
result[key] = [
|
||||||
|
item.strip() for item in value.split(",") if item.strip()
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
def load_rate_limit_config_from_env(
|
def load_rate_limit_config_from_env(
|
||||||
self,
|
self,
|
||||||
env_source: Mapping[str, str] | None = None,
|
env_source: Mapping[str, str] | None = None,
|
||||||
@@ -294,13 +295,21 @@ class ConfigLoader:
|
|||||||
ConfigurationError: If configuration is invalid.
|
ConfigurationError: If configuration is invalid.
|
||||||
"""
|
"""
|
||||||
raw_config = self._extract_env_config(
|
raw_config = self._extract_env_config(
|
||||||
"RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source
|
"RATE_LIMIT_", _RATE_LIMIT_FIELDS, env_source
|
||||||
)
|
)
|
||||||
config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
|
|
||||||
|
_check_non_loadable(raw_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||||
|
except ValidationError as e:
|
||||||
|
raise ConfigurationError(_format_validation_error(e)) from e
|
||||||
|
|
||||||
|
config_dict = schema.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
# Apply overrides
|
# Apply overrides
|
||||||
for key, value in overrides.items():
|
for key, value in overrides.items():
|
||||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES:
|
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||||
config_dict[key] = value
|
config_dict[key] = value
|
||||||
|
|
||||||
# Ensure required field 'limit' is present
|
# Ensure required field 'limit' is present
|
||||||
@@ -353,11 +362,19 @@ class ConfigLoader:
|
|||||||
if not isinstance(raw_config, dict):
|
if not isinstance(raw_config, dict):
|
||||||
msg = "JSON root must be an object"
|
msg = "JSON root must be an object"
|
||||||
raise ConfigurationError(msg)
|
raise ConfigurationError(msg)
|
||||||
config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
|
|
||||||
|
_check_non_loadable(raw_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||||
|
except ValidationError as e:
|
||||||
|
raise ConfigurationError(_format_validation_error(e)) from e
|
||||||
|
|
||||||
|
config_dict = schema.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
# Apply overrides
|
# Apply overrides
|
||||||
for key, value in overrides.items():
|
for key, value in overrides.items():
|
||||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELD_TYPES:
|
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||||
config_dict[key] = value
|
config_dict[key] = value
|
||||||
|
|
||||||
# Ensure required field 'limit' is present
|
# Ensure required field 'limit' is present
|
||||||
@@ -388,13 +405,24 @@ class ConfigLoader:
|
|||||||
ConfigurationError: If configuration is invalid.
|
ConfigurationError: If configuration is invalid.
|
||||||
"""
|
"""
|
||||||
raw_config = self._extract_env_config(
|
raw_config = self._extract_env_config(
|
||||||
"GLOBAL_", _GLOBAL_FIELD_TYPES, env_source
|
"GLOBAL_", _GLOBAL_FIELDS, env_source
|
||||||
)
|
)
|
||||||
config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
|
|
||||||
|
_check_non_loadable(raw_config)
|
||||||
|
|
||||||
|
# Pre-process comma-separated strings into lists for set fields
|
||||||
|
processed = self._parse_set_from_string(raw_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = _GlobalSchema(**processed) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||||
|
except ValidationError as e:
|
||||||
|
raise ConfigurationError(_format_validation_error(e)) from e
|
||||||
|
|
||||||
|
config_dict = schema.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
# Apply overrides
|
# Apply overrides
|
||||||
for key, value in overrides.items():
|
for key, value in overrides.items():
|
||||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES:
|
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS:
|
||||||
config_dict[key] = value
|
config_dict[key] = value
|
||||||
|
|
||||||
return GlobalConfig(**config_dict)
|
return GlobalConfig(**config_dict)
|
||||||
@@ -439,11 +467,23 @@ class ConfigLoader:
|
|||||||
"""
|
"""
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
raw_config = self._load_json_file(path)
|
raw_config = self._load_json_file(path)
|
||||||
config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
|
|
||||||
|
if not isinstance(raw_config, dict):
|
||||||
|
msg = "JSON root must be an object"
|
||||||
|
raise ConfigurationError(msg)
|
||||||
|
|
||||||
|
_check_non_loadable(raw_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = _GlobalSchema(**raw_config)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise ConfigurationError(_format_validation_error(e)) from e
|
||||||
|
|
||||||
|
config_dict = schema.model_dump(exclude_defaults=True)
|
||||||
|
|
||||||
# Apply overrides
|
# Apply overrides
|
||||||
for key, value in overrides.items():
|
for key, value in overrides.items():
|
||||||
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELD_TYPES:
|
if key in _NON_LOADABLE_FIELDS or key in _GLOBAL_FIELDS:
|
||||||
config_dict[key] = value
|
config_dict[key] = value
|
||||||
|
|
||||||
return GlobalConfig(**config_dict)
|
return GlobalConfig(**config_dict)
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ def rate_limit(
|
|||||||
/,
|
/,
|
||||||
) -> Callable[[F], F]: ...
|
) -> Callable[[F], F]: ...
|
||||||
|
|
||||||
|
|
||||||
def rate_limit(
|
def rate_limit(
|
||||||
limit: int,
|
limit: int,
|
||||||
window_size: float = 60.0,
|
window_size: float = 60.0,
|
||||||
@@ -139,9 +138,23 @@ def rate_limit(
|
|||||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
return asyncio.get_event_loop().run_until_complete(
|
async def _sync_rate_limit() -> Any:
|
||||||
async_wrapper(*args, **kwargs)
|
request = _extract_request(args, kwargs)
|
||||||
)
|
if request is None:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
limiter = get_limiter()
|
||||||
|
result = await limiter.hit(request, config)
|
||||||
|
|
||||||
|
response = func(*args, **kwargs)
|
||||||
|
|
||||||
|
if config.include_headers and hasattr(response, "headers"):
|
||||||
|
for key, value in result.info.to_headers().items():
|
||||||
|
response.headers[key] = value
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(_sync_rate_limit())
|
||||||
|
|
||||||
if _is_coroutine_function(func):
|
if _is_coroutine_function(func):
|
||||||
return async_wrapper # type: ignore[return-value]
|
return async_wrapper # type: ignore[return-value]
|
||||||
|
|||||||
6
main.py
6
main.py
@@ -1,6 +0,0 @@
|
|||||||
def main():
|
|
||||||
print("Hello from fastapi-traffic!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-traffic"
|
name = "fastapi-traffic"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
description = "Production-grade rate limiting for FastAPI with multiple algorithms and backends"
|
description = "Production-grade rate limiting for FastAPI with multiple algorithms and backends"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
@@ -22,6 +22,7 @@ classifiers = [
|
|||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"pydantic>=2.0",
|
||||||
"starlette>=0.27.0",
|
"starlette>=0.27.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -42,9 +43,9 @@ dev = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Documentation = "https://gitlab.com/fastapi-traffic/fastapi-traffic#readme"
|
Documentation = "https://gitlab.com/zanewalker/fastapi-traffic#readme"
|
||||||
Repository = "https://github.com/fastapi-traffic/fastapi-traffic"
|
Repository = "https://gitlab.com/zanewalker/fastapi-traffic"
|
||||||
Issues = "https://gitlab.com/bereckobrian/fastapi-traffic/issues"
|
Issues = "https://gitlab.com/zanewalker/fastapi-traffic/issues"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
@@ -108,8 +109,8 @@ reportUnknownVariableType = false
|
|||||||
reportUnknownParameterType = false
|
reportUnknownParameterType = false
|
||||||
reportMissingImports = false
|
reportMissingImports = false
|
||||||
reportUnusedFunction = false
|
reportUnusedFunction = false
|
||||||
reportInvalidTypeArguments = false
|
reportInvalidTypeArguments = true
|
||||||
reportGeneralTypeIssues = false
|
reportGeneralTypeIssues = true
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|||||||
@@ -365,7 +365,7 @@ class TestConfigLoaderValidation:
|
|||||||
"FASTAPI_TRAFFIC_RATE_LIMIT_ALGORITHM": "invalid_algorithm",
|
"FASTAPI_TRAFFIC_RATE_LIMIT_ALGORITHM": "invalid_algorithm",
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ConfigurationError, match="Cannot parse value"):
|
with pytest.raises(ConfigurationError):
|
||||||
loader.load_rate_limit_config_from_env(env_vars)
|
loader.load_rate_limit_config_from_env(env_vars)
|
||||||
|
|
||||||
def test_invalid_int_value(self, loader: ConfigLoader) -> None:
|
def test_invalid_int_value(self, loader: ConfigLoader) -> None:
|
||||||
@@ -374,7 +374,7 @@ class TestConfigLoaderValidation:
|
|||||||
"FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT": "not_a_number",
|
"FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT": "not_a_number",
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ConfigurationError, match="Cannot parse value"):
|
with pytest.raises(ConfigurationError):
|
||||||
loader.load_rate_limit_config_from_env(env_vars)
|
loader.load_rate_limit_config_from_env(env_vars)
|
||||||
|
|
||||||
def test_invalid_float_value(self, loader: ConfigLoader) -> None:
|
def test_invalid_float_value(self, loader: ConfigLoader) -> None:
|
||||||
@@ -384,7 +384,7 @@ class TestConfigLoaderValidation:
|
|||||||
"FASTAPI_TRAFFIC_RATE_LIMIT_WINDOW_SIZE": "not_a_float",
|
"FASTAPI_TRAFFIC_RATE_LIMIT_WINDOW_SIZE": "not_a_float",
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ConfigurationError, match="Cannot parse value"):
|
with pytest.raises(ConfigurationError):
|
||||||
loader.load_rate_limit_config_from_env(env_vars)
|
loader.load_rate_limit_config_from_env(env_vars)
|
||||||
|
|
||||||
def test_unknown_field(self, loader: ConfigLoader, temp_dir: Path) -> None:
|
def test_unknown_field(self, loader: ConfigLoader, temp_dir: Path) -> None:
|
||||||
@@ -411,7 +411,7 @@ class TestConfigLoaderValidation:
|
|||||||
config_data = {"limit": "not_an_int"}
|
config_data = {"limit": "not_an_int"}
|
||||||
json_file.write_text(json.dumps(config_data))
|
json_file.write_text(json.dumps(config_data))
|
||||||
|
|
||||||
with pytest.raises(ConfigurationError, match="Cannot parse value"):
|
with pytest.raises(ConfigurationError):
|
||||||
loader.load_rate_limit_config_from_json(json_file)
|
loader.load_rate_limit_config_from_json(json_file)
|
||||||
|
|
||||||
def test_bool_parsing_variations(self, loader: ConfigLoader) -> None:
|
def test_bool_parsing_variations(self, loader: ConfigLoader) -> None:
|
||||||
|
|||||||
4
uv.lock
generated
4
uv.lock
generated
@@ -259,9 +259,10 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi-traffic"
|
name = "fastapi-traffic"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "pydantic" },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -304,6 +305,7 @@ requires-dist = [
|
|||||||
{ name = "fastapi", marker = "extra == 'dev'", specifier = ">=0.100.0" },
|
{ name = "fastapi", marker = "extra == 'dev'", specifier = ">=0.100.0" },
|
||||||
{ name = "fastapi", marker = "extra == 'fastapi'", specifier = ">=0.100.0" },
|
{ name = "fastapi", marker = "extra == 'fastapi'", specifier = ">=0.100.0" },
|
||||||
{ name = "httpx", marker = "extra == 'dev'", specifier = ">=0.27.0" },
|
{ name = "httpx", marker = "extra == 'dev'", specifier = ">=0.27.0" },
|
||||||
|
{ name = "pydantic", specifier = ">=2.0" },
|
||||||
{ name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.350" },
|
{ name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.350" },
|
||||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user