release: bump version to 0.3.0

- Refactor Redis backend connection handling and pool management
- Update algorithm implementations with improved type annotations
- Enhance config loader validation with stricter Pydantic schemas
- Improve decorator and middleware error handling
- Expand example scripts with better docstrings and usage patterns
- Add new 00_basic_usage.py example for quick start
- Reorganize examples directory structure
- Fix type annotation inconsistencies across core modules
- Update dependencies in pyproject.toml
This commit is contained in:
2026-03-17 20:55:38 +00:00
parent 492410614f
commit f3453cb0fc
51 changed files with 6507 additions and 166 deletions

View File

@@ -20,7 +20,7 @@ from fastapi_traffic.exceptions import (
RateLimitExceeded,
)
__version__ = "0.2.0"
__version__ = "0.3.0"
__all__ = [
"Algorithm",
"Backend",

View File

@@ -19,7 +19,7 @@ class RedisBackend(Backend):
def __init__(
self,
client: Redis[bytes],
client: Redis,
*,
key_prefix: str = "fastapi_traffic",
) -> None:
@@ -57,7 +57,8 @@ class RedisBackend(Backend):
msg = "redis package is required for RedisBackend. Install with: pip install redis"
raise ImportError(msg) from e
client: Redis[bytes] = Redis.from_url(url, **kwargs)
client: Redis = Redis.from_url(url, **kwargs) # pyright: ignore[reportUnknownMemberType] # fmt: skip
# note: No type stubs for redis-py, so we ignore the type errors
instance = cls(client, key_prefix=key_prefix)
instance._owns_client = True
return instance
@@ -119,7 +120,11 @@ class RedisBackend(Backend):
pattern = f"{self._key_prefix}:*"
cursor: int = 0
while True:
cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
cursor, keys = (
await self._client.scan( # pyright: ignore[reportUnknownMemberType]
cursor, match=pattern, count=100
)
)
if keys:
await self._client.delete(*keys)
if cursor == 0:
@@ -134,11 +139,7 @@ class RedisBackend(Backend):
async def ping(self) -> bool:
"""Check if Redis is reachable."""
try:
await self._client.ping()
return True
except Exception:
return False
return await self._client.ping() # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues, reportUnknownVariableType, reportReturnType] # fmt: skip
async def get_stats(self) -> dict[str, Any]:
"""Get statistics about the rate limit storage."""
@@ -147,12 +148,20 @@ class RedisBackend(Backend):
cursor: int = 0
count = 0
while True:
cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
cursor, keys = (
await self._client.scan( # pyright: ignore[reportUnknownMemberType]
cursor, match=pattern, count=100
)
)
count += len(keys)
if cursor == 0:
break
info = await self._client.info("memory")
info: dict[str, Any] = (
await self._client.info( # pyright: ignore[reportUnknownMemberType]
"memory"
)
)
return {
"total_keys": count,
"used_memory": info.get("used_memory_human", "unknown"),

View File

@@ -89,8 +89,7 @@ class TokenBucketAlgorithm(BaseAlgorithm):
remaining=int(tokens),
reset_at=now + self.window_size,
window_size=self.window_size,
retry_after = (1 - tokens) / self.refill_rate
retry_after=(1 - tokens) / self.refill_rate,
)
tokens = float(state.get("tokens", self.burst_size))

View File

@@ -77,6 +77,6 @@ class GlobalConfig:
error_message: str = "Rate limit exceeded. Please try again later."
status_code: int = 429
skip_on_error: bool = False
exempt_ips: set[str] = field(default_factory=set)
exempt_paths: set[str] = field(default_factory=set)
exempt_ips: set[str] = field(default_factory=set[str])
exempt_paths: set[str] = field(default_factory=set[str])
headers_prefix: str = "X-RateLimit"

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, cast
from pydantic import BaseModel, ConfigDict, ValidationError, field_validator
@@ -300,6 +300,14 @@ class ConfigLoader:
_check_non_loadable(raw_config)
# Merge loadable overrides before validation so required fields can be supplied
non_loadable_overrides: dict[str, Any] = {}
for key, value in overrides.items():
if key in _NON_LOADABLE_FIELDS:
non_loadable_overrides[key] = value
elif key in _RATE_LIMIT_FIELDS:
raw_config[key] = value
try:
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
except ValidationError as e:
@@ -307,10 +315,8 @@ class ConfigLoader:
config_dict = schema.model_dump(exclude_defaults=True)
# Apply overrides
for key, value in overrides.items():
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
config_dict[key] = value
# Apply non-loadable overrides (callables, etc.)
config_dict.update(non_loadable_overrides)
# Ensure required field 'limit' is present
if "limit" not in config_dict:
@@ -363,7 +369,15 @@ class ConfigLoader:
msg = "JSON root must be an object"
raise ConfigurationError(msg)
_check_non_loadable(raw_config)
_check_non_loadable(cast("dict[str, Any]", raw_config))
# Merge loadable overrides before validation so required fields can be supplied
non_loadable_overrides: dict[str, Any] = {}
for key, value in overrides.items():
if key in _NON_LOADABLE_FIELDS:
non_loadable_overrides[key] = value
elif key in _RATE_LIMIT_FIELDS:
raw_config[key] = value
try:
schema = _RateLimitSchema(**raw_config) # type: ignore[arg-type] # Pydantic coerces str→typed values at runtime
@@ -372,10 +386,8 @@ class ConfigLoader:
config_dict = schema.model_dump(exclude_defaults=True)
# Apply overrides
for key, value in overrides.items():
if key in _NON_LOADABLE_FIELDS or key in _RATE_LIMIT_FIELDS:
config_dict[key] = value
# Apply non-loadable overrides (callables, etc.)
config_dict.update(non_loadable_overrides)
# Ensure required field 'limit' is present
if "limit" not in config_dict:
@@ -404,9 +416,7 @@ class ConfigLoader:
Raises:
ConfigurationError: If configuration is invalid.
"""
raw_config = self._extract_env_config(
"GLOBAL_", _GLOBAL_FIELDS, env_source
)
raw_config = self._extract_env_config("GLOBAL_", _GLOBAL_FIELDS, env_source)
_check_non_loadable(raw_config)
@@ -472,10 +482,10 @@ class ConfigLoader:
msg = "JSON root must be an object"
raise ConfigurationError(msg)
_check_non_loadable(raw_config)
_check_non_loadable(cast("dict[str, Any]", raw_config))
try:
schema = _GlobalSchema(**raw_config)
schema = _GlobalSchema(**cast("dict[str, Any]", raw_config))
except ValidationError as e:
raise ConfigurationError(_format_validation_error(e)) from e

View File

@@ -51,6 +51,7 @@ def rate_limit(
/,
) -> Callable[[F], F]: ...
def rate_limit(
limit: int,
window_size: float = 60.0,
@@ -243,7 +244,12 @@ class RateLimitDependency:
exempt_when=exempt_when,
)
async def __call__(self, request: Request) -> Any:
async def __call__(
self,
request: (
Request | Any
), # Actually Request, but using Any to avoid Pydantic schema issues
) -> Any:
"""Check rate limit and return info."""
limiter = get_limiter()
result = await limiter.hit(request, self._config)

View File

@@ -60,7 +60,7 @@ class TokenBucketState:
class SlidingWindowState:
"""State for sliding window algorithm."""
timestamps: list[float] = field(default_factory=list)
timestamps: list[float] = field(default_factory=list[float])
count: int = 0