release: bump version to 0.3.0
- Refactor Redis backend connection handling and pool management - Update algorithm implementations with improved type annotations - Enhance config loader validation with stricter Pydantic schemas - Improve decorator and middleware error handling - Expand example scripts with better docstrings and usage patterns - Add new 00_basic_usage.py example for quick start - Reorganize examples directory structure - Fix type annotation inconsistencies across core modules - Update dependencies in pyproject.toml
This commit is contained in:
@@ -20,7 +20,7 @@ from fastapi_traffic.exceptions import (
|
||||
RateLimitExceeded,
|
||||
)
|
||||
|
||||
__version__ = "0.2.0"
|
||||
__version__ = "0.3.0"
|
||||
__all__ = [
|
||||
"Algorithm",
|
||||
"Backend",
|
||||
|
||||
@@ -19,7 +19,7 @@ class RedisBackend(Backend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Redis[bytes],
|
||||
client: Redis,
|
||||
*,
|
||||
key_prefix: str = "fastapi_traffic",
|
||||
) -> None:
|
||||
@@ -57,7 +57,8 @@ class RedisBackend(Backend):
|
||||
msg = "redis package is required for RedisBackend. Install with: pip install redis"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
client: Redis[bytes] = Redis.from_url(url, **kwargs)
|
||||
client: Redis = Redis.from_url(url, **kwargs) # pyright: ignore[reportUnknownMemberType] # fmt: skip
|
||||
# note: No type stubs for redis-py, so we ignore the type errors
|
||||
instance = cls(client, key_prefix=key_prefix)
|
||||
instance._owns_client = True
|
||||
return instance
|
||||
@@ -119,7 +120,11 @@ class RedisBackend(Backend):
|
||||
pattern = f"{self._key_prefix}:*"
|
||||
cursor: int = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
|
||||
cursor, keys = (
|
||||
await self._client.scan( # pyright: ignore[reportUnknownMemberType]
|
||||
cursor, match=pattern, count=100
|
||||
)
|
||||
)
|
||||
if keys:
|
||||
await self._client.delete(*keys)
|
||||
if cursor == 0:
|
||||
@@ -134,11 +139,7 @@ class RedisBackend(Backend):
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Check if Redis is reachable."""
|
||||
try:
|
||||
await self._client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return await self._client.ping() # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues, reportUnknownVariableType, reportReturnType] # fmt: skip
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about the rate limit storage."""
|
||||
@@ -147,12 +148,20 @@ class RedisBackend(Backend):
|
||||
cursor: int = 0
|
||||
count = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
|
||||
cursor, keys = (
|
||||
await self._client.scan( # pyright: ignore[reportUnknownMemberType]
|
||||
cursor, match=pattern, count=100
|
||||
)
|
||||
)
|
||||
count += len(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
info = await self._client.info("memory")
|
||||
info: dict[str, Any] = (
|
||||
await self._client.info( # pyright: ignore[reportUnknownMemberType]
|
||||
"memory"
|
||||
)
|
||||
)
|
||||
return {
|
||||
"total_keys": count,
|
||||
"used_memory": info.get("used_memory_human", "unknown"),
|
||||
|
||||
@@ -89,8 +89,7 @@ class TokenBucketAlgorithm(BaseAlgorithm):
|
||||
remaining=int(tokens),
|
||||
reset_at=now + self.window_size,
|
||||
window_size=self.window_size,
|
||||
retry_after = (1 - tokens) / self.refill_rate
|
||||
|
||||
retry_after=(1 - tokens) / self.refill_rate,
|
||||
)
|
||||
|
||||
tokens = float(state.get("tokens", self.burst_size))
|
||||
|
||||
@@ -77,6 +77,6 @@ class GlobalConfig:
|
||||
error_message: str = "Rate limit exceeded. Please try again later."
|
||||
status_code: int = 429
|
||||
skip_on_error: bool = False
|
||||
exempt_ips: set[str] = field(default_factory=set)
|
||||
exempt_paths: set[str] = field(default_factory=set)
|
||||
exempt_ips: set[str] = field(default_factory=set[str])
|
||||
exempt_paths: set[str] = field(default_factory=set[str])
|
||||
headers_prefix: str = "X-RateLimit"
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
|
||||
|
||||
@@ -300,6 +300,14 @@ class ConfigLoader:
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
# Merge loadable overrides before validation so required fields can be supplied
|
||||
non_loadable_overrides: dict[str, Any] = {}
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS:
|
||||
non_loadable_overrides[key] = value
|
||||
elif key in _RATE_LIMIT_FIELDS:
|
||||
raw_config[key] = value
|
||||
|
||||
try:
|
||||
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||
except ValidationError as e:
|
||||
@@ -307,10 +315,8 @@ class ConfigLoader:
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||
config_dict[key] = value
|
||||
# Apply non-loadable overrides (callables, etc.)
|
||||
config_dict.update(non_loadable_overrides)
|
||||
|
||||
# Ensure required field 'limit' is present
|
||||
if "limit" not in config_dict:
|
||||
@@ -363,7 +369,15 @@ class ConfigLoader:
|
||||
msg = "JSON root must be an object"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
_check_non_loadable(cast("dict[str, Any]", raw_config))
|
||||
|
||||
# Merge loadable overrides before validation so required fields can be supplied
|
||||
non_loadable_overrides: dict[str, Any] = {}
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS:
|
||||
non_loadable_overrides[key] = value
|
||||
elif key in _RATE_LIMIT_FIELDS:
|
||||
raw_config[key] = value
|
||||
|
||||
try:
|
||||
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
|
||||
@@ -372,10 +386,8 @@ class ConfigLoader:
|
||||
|
||||
config_dict = schema.model_dump(exclude_defaults=True)
|
||||
|
||||
# Apply overrides
|
||||
for key, value in overrides.items():
|
||||
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
|
||||
config_dict[key] = value
|
||||
# Apply non-loadable overrides (callables, etc.)
|
||||
config_dict.update(non_loadable_overrides)
|
||||
|
||||
# Ensure required field 'limit' is present
|
||||
if "limit" not in config_dict:
|
||||
@@ -404,9 +416,7 @@ class ConfigLoader:
|
||||
Raises:
|
||||
ConfigurationError: If configuration is invalid.
|
||||
"""
|
||||
raw_config = self._extract_env_config(
|
||||
"GLOBAL_", _GLOBAL_FIELDS, env_source
|
||||
)
|
||||
raw_config = self._extract_env_config("GLOBAL_", _GLOBAL_FIELDS, env_source)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
|
||||
@@ -472,10 +482,10 @@ class ConfigLoader:
|
||||
msg = "JSON root must be an object"
|
||||
raise ConfigurationError(msg)
|
||||
|
||||
_check_non_loadable(raw_config)
|
||||
_check_non_loadable(cast("dict[str, Any]", raw_config))
|
||||
|
||||
try:
|
||||
schema = _GlobalSchema(**raw_config)
|
||||
schema = _GlobalSchema(**cast("dict[str, Any]", raw_config))
|
||||
except ValidationError as e:
|
||||
raise ConfigurationError(_format_validation_error(e)) from e
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ def rate_limit(
|
||||
/,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
def rate_limit(
|
||||
limit: int,
|
||||
window_size: float = 60.0,
|
||||
@@ -243,7 +244,12 @@ class RateLimitDependency:
|
||||
exempt_when=exempt_when,
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
async def __call__(
|
||||
self,
|
||||
request: (
|
||||
Request | Any
|
||||
), # Actually Request, but using Any to avoid Pydantic schema issues
|
||||
) -> Any:
|
||||
"""Check rate limit and return info."""
|
||||
limiter = get_limiter()
|
||||
result = await limiter.hit(request, self._config)
|
||||
|
||||
@@ -60,7 +60,7 @@ class TokenBucketState:
|
||||
class SlidingWindowState:
|
||||
"""State for sliding window algorithm."""
|
||||
|
||||
timestamps: list[float] = field(default_factory=list)
|
||||
timestamps: list[float] = field(default_factory=list[float])
|
||||
count: int = 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user