Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
36
fastapi_traffic/__init__.py
Normal file
36
fastapi_traffic/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""FastAPI Traffic - Production-grade rate limiting for FastAPI."""
|
||||
|
||||
from fastapi_traffic.core.decorator import rate_limit
|
||||
from fastapi_traffic.core.limiter import RateLimiter
|
||||
from fastapi_traffic.core.config import RateLimitConfig
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.backends.sqlite import SQLiteBackend
|
||||
from fastapi_traffic.exceptions import (
|
||||
RateLimitExceeded,
|
||||
BackendError,
|
||||
ConfigurationError,
|
||||
)
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = [
|
||||
"rate_limit",
|
||||
"RateLimiter",
|
||||
"RateLimitConfig",
|
||||
"Algorithm",
|
||||
"Backend",
|
||||
"MemoryBackend",
|
||||
"SQLiteBackend",
|
||||
"RateLimitExceeded",
|
||||
"BackendError",
|
||||
"ConfigurationError",
|
||||
]
|
||||
|
||||
# Optional Redis backend
|
||||
try:
|
||||
from fastapi_traffic.backends.redis import RedisBackend
|
||||
|
||||
__all__.append("RedisBackend")
|
||||
except ImportError:
|
||||
pass
|
||||
19
fastapi_traffic/backends/__init__.py
Normal file
19
fastapi_traffic/backends/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Backend implementations for rate limit storage."""
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.backends.sqlite import SQLiteBackend
|
||||
|
||||
__all__ = [
|
||||
"Backend",
|
||||
"MemoryBackend",
|
||||
"SQLiteBackend",
|
||||
]
|
||||
|
||||
# Optional Redis backend
|
||||
try:
|
||||
from fastapi_traffic.backends.redis import RedisBackend
|
||||
|
||||
__all__.append("RedisBackend")
|
||||
except ImportError:
|
||||
pass
|
||||
89
fastapi_traffic/backends/base.py
Normal file
89
fastapi_traffic/backends/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Abstract base class for rate limit storage backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Backend(ABC):
    """Interface every rate limit storage backend must implement.

    Concrete backends persist per-key state dictionaries and expose a small
    set of async primitives the rate limiting algorithms are built on.
    Backends are usable as async context managers; leaving the context
    closes them.
    """

    @abstractmethod
    async def get(self, key: str) -> dict[str, Any] | None:
        """Return the stored state dictionary for *key*.

        Args:
            key: The rate limit key.

        Returns:
            The stored state dictionary, or None when *key* is absent.
        """
        ...

    @abstractmethod
    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Store *value* under *key*, expiring after *ttl* seconds.

        Args:
            key: The rate limit key.
            value: The state dictionary to store.
            ttl: Time-to-live in seconds.
        """
        ...

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Remove any state stored under *key*.

        Args:
            key: The rate limit key.
        """
        ...

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Report whether *key* currently holds state.

        Args:
            key: The rate limit key.

        Returns:
            True when the key exists, False otherwise.
        """
        ...

    @abstractmethod
    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically add *amount* to the counter stored at *key*.

        Args:
            key: The rate limit key.
            amount: The amount to add.

        Returns:
            The counter value after the addition.
        """
        ...

    @abstractmethod
    async def clear(self) -> None:
        """Drop every piece of rate limit state held by this backend."""
        ...

    async def close(self) -> None:
        """Release backend resources; the default implementation is a no-op."""

    async def __aenter__(self) -> Backend:
        """Enter the async context, yielding the backend itself."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Close the backend when the async context exits."""
        await self.close()
|
||||
139
fastapi_traffic/backends/memory.py
Normal file
139
fastapi_traffic/backends/memory.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""In-memory backend for rate limiting - suitable for single-process applications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
|
||||
class MemoryBackend(Backend):
    """Thread-safe in-memory backend with LRU eviction and TTL support.

    State lives in an ``OrderedDict`` guarded by an asyncio lock; once
    ``max_size`` is exceeded the least recently used entries are evicted,
    and an optional background task periodically purges expired entries.
    """

    __slots__ = ("_data", "_lock", "_max_size", "_cleanup_interval", "_cleanup_task")

    def __init__(
        self,
        *,
        max_size: int = 10000,
        cleanup_interval: float = 60.0,
    ) -> None:
        """Initialize the memory backend.

        Args:
            max_size: Maximum number of entries to store (LRU eviction).
            cleanup_interval: Interval in seconds for cleaning expired entries.
        """
        # Maps key -> (state dict, absolute expiry timestamp).
        self._data: OrderedDict[str, tuple[dict[str, Any], float]] = OrderedDict()
        self._lock = asyncio.Lock()
        self._max_size = max_size
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start_cleanup(self) -> None:
        """Spawn the background cleanup task (idempotent)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Periodically purge expired entries until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Best-effort maintenance: never let a purge failure kill the loop.
                pass

    async def _cleanup_expired(self) -> None:
        """Drop every entry whose expiry timestamp has passed."""
        cutoff = time.time()
        async with self._lock:
            stale = [
                k for k, (_, deadline) in self._data.items() if deadline <= cutoff
            ]
            for k in stale:
                del self._data[k]

    def _evict_if_needed(self) -> None:
        """Pop oldest entries until within max size (caller must hold the lock)."""
        while len(self._data) > self._max_size:
            self._data.popitem(last=False)

    async def get(self, key: str) -> dict[str, Any] | None:
        """Return a copy of the state for *key*, or None if missing or expired."""
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return None

            state, deadline = entry
            if deadline <= time.time():
                # Lazily expire on read.
                del self._data[key]
                return None

            # Touch the entry so LRU eviction treats it as recently used.
            self._data.move_to_end(key)
            return state.copy()

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Store a copy of *value* under *key*, expiring after *ttl* seconds."""
        deadline = time.time() + ttl
        async with self._lock:
            self._data[key] = (value.copy(), deadline)
            self._data.move_to_end(key)
            self._evict_if_needed()

    async def delete(self, key: str) -> None:
        """Remove *key* if present; missing keys are ignored."""
        async with self._lock:
            self._data.pop(key, None)

    async def exists(self, key: str) -> bool:
        """Return True when *key* holds a non-expired entry."""
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return False

            if entry[1] <= time.time():
                del self._data[key]
                return False

            return True

    async def increment(self, key: str, amount: int = 1) -> int:
        """Add *amount* to the live counter at *key* and return the new total.

        A missing or expired entry is not created; the call simply returns
        *amount* in that case.
        """
        async with self._lock:
            entry = self._data.get(key)
            if entry is not None:
                state, deadline = entry
                if deadline > time.time():
                    total = int(state.get("count", 0)) + amount
                    state["count"] = total
                    return total

            return amount

    async def clear(self) -> None:
        """Discard every stored entry."""
        async with self._lock:
            self._data.clear()

    async def close(self) -> None:
        """Cancel the cleanup task (if running) and drop all data."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        await self.clear()

    def __len__(self) -> int:
        """Number of entries currently stored (expired ones included)."""
        return len(self._data)
|
||||
232
fastapi_traffic/backends/redis.py
Normal file
232
fastapi_traffic/backends/redis.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Redis backend for rate limiting - distributed storage for multi-node deployments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
from fastapi_traffic.exceptions import BackendError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class RedisBackend(Backend):
    """Redis-based backend for distributed rate limiting.

    Stores each key's state as a JSON blob and relies on Redis TTLs for
    expiry, so multiple application nodes share one consistent view.
    """

    __slots__ = ("_client", "_key_prefix", "_owns_client")

    def __init__(
        self,
        client: Redis[bytes],
        *,
        key_prefix: str = "fastapi_traffic",
    ) -> None:
        """Initialize the Redis backend.

        Args:
            client: An async Redis client instance.
            key_prefix: Prefix for all rate limit keys.
        """
        self._client = client
        self._key_prefix = key_prefix
        # Only close clients we created ourselves (see from_url()).
        self._owns_client = False

    @classmethod
    async def from_url(
        cls,
        url: str = "redis://localhost:6379/0",
        *,
        key_prefix: str = "fastapi_traffic",
        **kwargs: Any,
    ) -> RedisBackend:
        """Create a RedisBackend from a Redis URL.

        Args:
            url: Redis connection URL.
            key_prefix: Prefix for all rate limit keys.
            **kwargs: Additional arguments passed to Redis.from_url().

        Returns:
            A new RedisBackend instance.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        try:
            from redis.asyncio import Redis
        except ImportError as e:
            msg = "redis package is required for RedisBackend. Install with: pip install redis"
            raise ImportError(msg) from e

        client: Redis[bytes] = Redis.from_url(url, **kwargs)
        instance = cls(client, key_prefix=key_prefix)
        # We created the client, so close() should dispose of it.
        instance._owns_client = True
        return instance

    def _make_key(self, key: str) -> str:
        """Create a prefixed key to namespace our data in the shared Redis."""
        return f"{self._key_prefix}:{key}"

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key.

        Raises:
            BackendError: If the Redis operation or JSON decoding fails.
        """
        try:
            full_key = self._make_key(key)
            data = await self._client.get(full_key)
            if data is None:
                return None
            result: dict[str, Any] = json.loads(data)
            return result
        except Exception as e:
            # Chain explicitly so the Redis error is preserved as __cause__.
            raise BackendError(f"Failed to get key {key}", original_error=e) from e

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            data = json.dumps(value)
            # SETEX takes whole seconds; round up so entries never expire early.
            await self._client.setex(full_key, int(ttl) + 1, data)
        except Exception as e:
            raise BackendError(f"Failed to set key {key}", original_error=e) from e

    async def delete(self, key: str) -> None:
        """Delete the state for a key.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            await self._client.delete(full_key)
        except Exception as e:
            raise BackendError(f"Failed to delete key {key}", original_error=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            result = await self._client.exists(full_key)
            return bool(result)
        except Exception as e:
            raise BackendError(f"Failed to check key {key}", original_error=e) from e

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment a counter using Redis INCRBY.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            result = await self._client.incrby(full_key, amount)
            return int(result)
        except Exception as e:
            raise BackendError(
                f"Failed to increment key {key}", original_error=e
            ) from e

    async def clear(self) -> None:
        """Clear all rate limit data with this prefix.

        Uses SCAN (not KEYS) so a large keyspace does not block Redis.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                if keys:
                    await self._client.delete(*keys)
                if cursor == 0:
                    break
        except Exception as e:
            raise BackendError("Failed to clear rate limits", original_error=e) from e

    async def close(self) -> None:
        """Close the Redis connection if we own it."""
        if self._owns_client:
            await self._client.aclose()

    async def ping(self) -> bool:
        """Check if Redis is reachable; never raises."""
        try:
            await self._client.ping()
            return True
        except Exception:
            return False

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with the key count, Redis memory usage, and key prefix.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            count = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                count += len(keys)
                if cursor == 0:
                    break

            info = await self._client.info("memory")
            return {
                "total_keys": count,
                "used_memory": info.get("used_memory_human", "unknown"),
                "key_prefix": self._key_prefix,
            }
        except Exception as e:
            raise BackendError("Failed to get stats", original_error=e) from e
|
||||
|
||||
|
||||
# Lua scripts for atomic operations
# These run server-side via EVAL/EVALSHA so the whole read-modify-write
# cycle for a key executes atomically inside Redis, avoiding races between
# application nodes.

# Sliding-window log over a sorted set of request timestamps.
# KEYS[1] = zset key; ARGV = (now, window_size, limit).
# Returns {1, remaining} when allowed, {0, 0, retry_after} when limited.
SLIDING_WINDOW_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local window_start = now - window_size

-- Remove expired entries
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)

-- Count current entries
local count = redis.call('ZCARD', key)

if count < limit then
    -- Add new entry
    redis.call('ZADD', key, now, now .. ':' .. math.random())
    redis.call('EXPIRE', key, math.ceil(window_size) + 1)
    return {1, limit - count - 1}
else
    -- Get oldest entry for retry-after calculation
    local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
    local retry_after = 0
    if #oldest > 0 then
        retry_after = oldest[2] + window_size - now
    end
    return {0, 0, retry_after}
end
"""

# Token bucket stored as a JSON blob ({tokens, last_update}).
# KEYS[1] = state key; ARGV = (now, bucket_size, refill_rate, ttl).
# Returns {allowed, floor(tokens), retry_after}.
TOKEN_BUCKET_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local bucket_size = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
local tokens, last_update

if data then
    local decoded = cjson.decode(data)
    tokens = decoded.tokens
    last_update = decoded.last_update
else
    tokens = bucket_size
    last_update = now
end

-- Refill tokens
local elapsed = now - last_update
tokens = math.min(bucket_size, tokens + elapsed * refill_rate)

local allowed = 0
local retry_after = 0

if tokens >= 1 then
    tokens = tokens - 1
    allowed = 1
else
    retry_after = (1 - tokens) / refill_rate
end

-- Save state
redis.call('SETEX', key, ttl, cjson.encode({tokens = tokens, last_update = now}))

return {allowed, math.floor(tokens), retry_after}
"""
|
||||
298
fastapi_traffic/backends/sqlite.py
Normal file
298
fastapi_traffic/backends/sqlite.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""SQLite backend for rate limiting - persistent storage for single-node deployments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
from fastapi_traffic.exceptions import BackendError
|
||||
|
||||
|
||||
class SQLiteBackend(Backend):
    """SQLite-based backend with connection pooling and async support.

    All blocking sqlite3 calls are pushed onto the default executor so the
    event loop is never blocked; a single autocommit (WAL) connection is
    shared, and an asyncio lock serializes compound read-modify-write
    operations such as increment().
    """

    __slots__ = (
        "_db_path",
        "_connection",
        "_lock",
        "_cleanup_interval",
        "_cleanup_task",
        "_pool_size",
        "_connections",
    )

    def __init__(
        self,
        db_path: str | Path = ":memory:",
        *,
        cleanup_interval: float = 300.0,
        pool_size: int = 5,
    ) -> None:
        """Initialize the SQLite backend.

        Args:
            db_path: Path to SQLite database file or ":memory:" for in-memory.
            cleanup_interval: Interval in seconds for cleaning expired entries.
            pool_size: Number of connections in the pool.
        """
        self._db_path = str(db_path)
        self._connection: sqlite3.Connection | None = None
        self._lock = asyncio.Lock()
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None
        # NOTE(review): the connection pool is not implemented yet — a single
        # shared connection is used. These attributes are kept so the public
        # constructor signature stays stable for a future pooled version.
        self._pool_size = pool_size
        self._connections: list[sqlite3.Connection] = []

    async def initialize(self) -> None:
        """Open the database, create tables, and start the cleanup task."""
        await self._ensure_connection()
        await self._create_tables()
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _ensure_connection(self) -> sqlite3.Connection:
        """Lazily open the shared connection on first use."""
        if self._connection is None:
            loop = asyncio.get_running_loop()
            self._connection = await loop.run_in_executor(
                None, self._create_connection
            )
        assert self._connection is not None
        return self._connection

    def _create_connection(self) -> sqlite3.Connection:
        """Create a new SQLite connection with optimized settings."""
        conn = sqlite3.connect(
            self._db_path,
            check_same_thread=False,
            # isolation_level=None puts sqlite3 in autocommit mode, so every
            # statement is committed immediately without explicit BEGIN/COMMIT.
            isolation_level=None,
        )
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.execute("PRAGMA cache_size=10000")
        conn.execute("PRAGMA temp_store=MEMORY")
        conn.row_factory = sqlite3.Row
        return conn

    async def _create_tables(self) -> None:
        """Create the rate limit tables."""
        conn = await self._ensure_connection()
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._create_tables_sync, conn)

    def _create_tables_sync(self, conn: sqlite3.Connection) -> None:
        """Synchronously create tables and the expiry index."""
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_limits (
                key TEXT PRIMARY KEY,
                data TEXT NOT NULL,
                expires_at REAL NOT NULL
            )
        """)
        # Index so the periodic cleanup DELETE does not scan the whole table.
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at)
        """)

    async def _cleanup_loop(self) -> None:
        """Background task to clean up expired entries until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Best-effort maintenance: never let a purge failure kill the loop.
                pass

    async def _cleanup_expired(self) -> None:
        """Remove expired entries.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: conn.execute(
                    "DELETE FROM rate_limits WHERE expires_at <= ?", (time.time(),)
                ),
            )
        except Exception as e:
            raise BackendError(
                "Failed to cleanup expired entries", original_error=e
            ) from e

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key (expired rows are deleted lazily).

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _get() -> dict[str, Any] | None:
                cursor = conn.execute(
                    "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                    (key,),
                )
                row = cursor.fetchone()
                if row is None:
                    return None

                expires_at = row["expires_at"]
                if expires_at <= time.time():
                    # Lazily expire on read.
                    conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,))
                    return None

                data: dict[str, Any] = json.loads(row["data"])
                return data

            return await loop.run_in_executor(None, _get)
        except Exception as e:
            raise BackendError(f"Failed to get key {key}", original_error=e) from e

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            expires_at = time.time() + ttl
            data_json = json.dumps(value)

            def _set() -> None:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO rate_limits (key, data, expires_at)
                    VALUES (?, ?, ?)
                    """,
                    (key, data_json, expires_at),
                )

            await loop.run_in_executor(None, _set)
        except Exception as e:
            raise BackendError(f"Failed to set key {key}", original_error=e) from e

    async def delete(self, key: str) -> None:
        """Delete the state for a key.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,)),
            )
        except Exception as e:
            raise BackendError(f"Failed to delete key {key}", original_error=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _exists() -> bool:
                cursor = conn.execute(
                    "SELECT 1 FROM rate_limits WHERE key = ? AND expires_at > ?",
                    (key, time.time()),
                )
                return cursor.fetchone() is not None

            return await loop.run_in_executor(None, _exists)
        except Exception as e:
            raise BackendError(f"Failed to check key {key}", original_error=e) from e

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment a counter.

        The read-modify-write is serialized by the asyncio lock; a missing
        or expired entry is not created, and *amount* is returned instead.

        Raises:
            BackendError: If the database operation fails.
        """
        async with self._lock:
            try:
                conn = await self._ensure_connection()
                loop = asyncio.get_running_loop()

                def _increment() -> int:
                    cursor = conn.execute(
                        "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                        (key,),
                    )
                    row = cursor.fetchone()

                    if row is None or row["expires_at"] <= time.time():
                        return amount

                    data: dict[str, Any] = json.loads(row["data"])
                    current = int(data.get("count", 0))
                    new_value = current + amount
                    data["count"] = new_value

                    conn.execute(
                        "UPDATE rate_limits SET data = ? WHERE key = ?",
                        (json.dumps(data), key),
                    )
                    return new_value

                return await loop.run_in_executor(None, _increment)
            except Exception as e:
                raise BackendError(
                    f"Failed to increment key {key}", original_error=e
                ) from e

    async def clear(self) -> None:
        """Clear all rate limit data.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, lambda: conn.execute("DELETE FROM rate_limits")
            )
        except Exception as e:
            raise BackendError("Failed to clear rate limits", original_error=e) from e

    async def close(self) -> None:
        """Stop the cleanup task and close all connections."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        if self._connection is not None:
            self._connection.close()
            self._connection = None

        for conn in self._connections:
            conn.close()
        self._connections.clear()

    async def vacuum(self) -> None:
        """Optimize the database by running VACUUM.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, lambda: conn.execute("VACUUM"))
        except Exception as e:
            raise BackendError("Failed to vacuum database", original_error=e) from e

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with total, active, and expired entry counts plus the path.

        Raises:
            BackendError: If the database operation fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _stats() -> dict[str, Any]:
                cursor = conn.execute("SELECT COUNT(*) as total FROM rate_limits")
                total = cursor.fetchone()["total"]

                cursor = conn.execute(
                    "SELECT COUNT(*) as active FROM rate_limits WHERE expires_at > ?",
                    (time.time(),),
                )
                active = cursor.fetchone()["active"]

                return {
                    "total_entries": total,
                    "active_entries": active,
                    "expired_entries": total - active,
                    "db_path": self._db_path,
                }

            return await loop.run_in_executor(None, _stats)
        except Exception as e:
            raise BackendError("Failed to get stats", original_error=e) from e
|
||||
16
fastapi_traffic/core/__init__.py
Normal file
16
fastapi_traffic/core/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Core rate limiting components."""
|
||||
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import RateLimitConfig
|
||||
from fastapi_traffic.core.decorator import rate_limit
|
||||
from fastapi_traffic.core.limiter import RateLimiter
|
||||
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
|
||||
|
||||
__all__ = [
|
||||
"Algorithm",
|
||||
"RateLimitConfig",
|
||||
"rate_limit",
|
||||
"RateLimiter",
|
||||
"RateLimitInfo",
|
||||
"RateLimitResult",
|
||||
]
|
||||
466
fastapi_traffic/core/algorithms.py
Normal file
466
fastapi_traffic/core/algorithms.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Rate limiting algorithms implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi_traffic.core.models import RateLimitInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
|
||||
class Algorithm(str, Enum):
    """Enumeration of the supported rate limiting strategies.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """

    TOKEN_BUCKET = "token_bucket"
    SLIDING_WINDOW = "sliding_window"
    FIXED_WINDOW = "fixed_window"
    LEAKY_BUCKET = "leaky_bucket"
    SLIDING_WINDOW_COUNTER = "sliding_window_counter"
|
||||
|
||||
|
||||
class BaseAlgorithm(ABC):
    """Common interface and shared state for rate limiting algorithms.

    Subclasses implement the check / reset / get_state trio on top of a
    storage backend. ``burst_size`` falls back to ``limit`` when not given
    (or when falsy).
    """

    __slots__ = ("limit", "window_size", "backend", "burst_size")

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        self.limit = limit
        self.window_size = window_size
        self.backend = backend
        self.burst_size = burst_size if burst_size else limit

    @abstractmethod
    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Consume one request slot for *key* and report whether it was allowed."""
        ...

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Discard all rate limit state recorded for *key*."""
        ...

    @abstractmethod
    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Inspect the current state for *key* without consuming a slot."""
        ...
|
||||
|
||||
|
||||
class TokenBucketAlgorithm(BaseAlgorithm):
    """Token bucket algorithm - allows bursts up to bucket capacity.

    Tokens refill continuously at ``limit / window_size`` per second and
    each request spends one, so short bursts up to ``burst_size`` pass
    while the long-run rate stays at ``limit`` per window.
    """

    __slots__ = ("refill_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        # Tokens restored per second.
        self.refill_rate = limit / window_size

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Spend one token for *key* if available and report the outcome."""
        now = time.time()
        stored = await self.backend.get(key)

        if stored is None:
            # First sighting of this key: full bucket minus this request.
            balance = float(self.burst_size - 1)
            await self.backend.set(
                key,
                {"tokens": balance, "last_update": now},
                ttl=self.window_size * 2,
            )
            return True, RateLimitInfo(
                limit=self.limit,
                remaining=int(balance),
                reset_at=now + self.window_size,
                window_size=self.window_size,
            )

        balance = float(stored.get("tokens", self.burst_size))
        stamp = float(stored.get("last_update", now))

        # Credit tokens accrued since the last update, capped at capacity.
        elapsed = now - stamp
        balance = min(self.burst_size, balance + elapsed * self.refill_rate)

        if balance >= 1:
            balance -= 1
            allowed, retry_after = True, None
        else:
            allowed = False
            retry_after = (1 - balance) / self.refill_rate

        await self.backend.set(
            key,
            {"tokens": balance, "last_update": now},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=int(balance),
            reset_at=now + (self.burst_size - balance) / self.refill_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Discard the bucket state for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the bucket state for *key* without spending a token."""
        now = time.time()
        stored = await self.backend.get(key)

        if stored is None:
            return None

        balance = float(stored.get("tokens", self.burst_size))
        stamp = float(stored.get("last_update", now))

        elapsed = now - stamp
        balance = min(self.burst_size, balance + elapsed * self.refill_rate)

        return RateLimitInfo(
            limit=self.limit,
            remaining=int(balance),
            reset_at=now + (self.burst_size - balance) / self.refill_rate,
            window_size=self.window_size,
        )
|
||||
|
||||
|
||||
class SlidingWindowAlgorithm(BaseAlgorithm):
    """Sliding window log algorithm - precise but memory intensive.

    Keeps the timestamp of every request inside the current window, so the
    count is exact at the cost of O(limit) storage per key.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Record a request for *key* if the window still has room."""
        now = time.time()
        cutoff = now - self.window_size
        stored = await self.backend.get(key)

        # Keep only the timestamps that still fall inside the window.
        log: list[float] = []
        if stored is not None:
            log = [float(t) for t in stored.get("timestamps", []) if float(t) > cutoff]

        if len(log) < self.limit:
            log.append(now)
            allowed, retry_after = True, None
        else:
            allowed = False
            earliest = min(log) if log else now
            retry_after = earliest + self.window_size - now

        await self.backend.set(
            key,
            {"timestamps": log},
            ttl=self.window_size * 2,
        )

        left = max(0, self.limit - len(log))
        next_reset = (min(log) if log else now) + self.window_size

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=left,
            reset_at=next_reset,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Forget every recorded request for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Inspect the window for *key* without recording a request."""
        now = time.time()
        cutoff = now - self.window_size
        stored = await self.backend.get(key)

        if stored is None:
            return None

        log = [float(t) for t in stored.get("timestamps", []) if float(t) > cutoff]
        left = max(0, self.limit - len(log))
        next_reset = (min(log) if log else now) + self.window_size

        return RateLimitInfo(
            limit=self.limit,
            remaining=left,
            reset_at=next_reset,
            window_size=self.window_size,
        )
|
||||
|
||||
|
||||
class FixedWindowAlgorithm(BaseAlgorithm):
    """Fixed window algorithm - simple and efficient.

    Counts requests per wall-clock-aligned window; the counter resets at
    each window boundary.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Consume one slot in the current window for *key*."""
        now = time.time()
        # Windows are aligned to multiples of window_size.
        window_start = (now // self.window_size) * self.window_size
        window_end = window_start + self.window_size

        stored = await self.backend.get(key)
        count = 0
        # A record from an earlier window is treated as an empty counter.
        if stored is not None and float(stored.get("window_start", 0)) == window_start:
            count = int(stored.get("count", 0))

        allowed = count < self.limit
        retry_after = None
        if allowed:
            count += 1
        else:
            retry_after = window_end - now

        await self.backend.set(
            key,
            {"count": count, "window_start": window_start},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - count),
            reset_at=window_end,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Drop any stored counter for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report current window usage for *key* without consuming a slot."""
        now = time.time()
        window_start = (now // self.window_size) * self.window_size
        window_end = window_start + self.window_size

        stored = await self.backend.get(key)
        if stored is None:
            return None

        count = 0
        if float(stored.get("window_start", 0)) == window_start:
            count = int(stored.get("count", 0))

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - count),
            reset_at=window_end,
            window_size=self.window_size,
        )
|
||||
|
||||
|
||||
class LeakyBucketAlgorithm(BaseAlgorithm):
    """Leaky bucket algorithm - smooths out bursts.

    Each request adds one unit of "water"; the bucket drains at a
    constant rate derived from limit / window_size.
    """

    __slots__ = ("leak_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        """Initialize and derive the constant drain rate."""
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        self.leak_rate = limit / window_size

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Try to pour one request into the bucket for *key*."""
        now = time.time()
        stored = await self.backend.get(key)

        level = 0.0
        if stored is not None:
            level = float(stored.get("water_level", 0))
            since = now - float(stored.get("last_update", now))
            # Drain whatever leaked out since the state was written.
            level = max(0, level - since * self.leak_rate)

        if level < self.burst_size:
            level += 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            # Time until enough water drains for one more request.
            retry_after = (level - self.burst_size + 1) / self.leak_rate

        await self.backend.set(
            key,
            {"water_level": level, "last_update": now},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            reset_at=now + level / self.leak_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Empty the bucket for *key* by deleting its state."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the drained water level for *key* without adding to it."""
        now = time.time()
        stored = await self.backend.get(key)
        if stored is None:
            return None

        level = float(stored.get("water_level", 0))
        since = now - float(stored.get("last_update", now))
        level = max(0, level - since * self.leak_rate)

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            reset_at=now + level / self.leak_rate,
            window_size=self.window_size,
        )
|
||||
|
||||
|
||||
class SlidingWindowCounterAlgorithm(BaseAlgorithm):
    """Sliding window counter - balance between precision and memory.

    Keeps one counter per fixed window and weights the previous window's
    count by how much of the current window has elapsed.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Record one request for *key* using weighted window counters."""
        now = time.time()
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        window_progress = (now - current_window) / self.window_size

        stored = await self.backend.get(key)

        prev_count = 0
        curr_count = 0
        if stored is not None:
            prev_count = int(stored.get("prev_count", 0))
            curr_count = int(stored.get("curr_count", 0))
            recorded = float(stored.get("current_window", 0))

            if recorded < previous_window:
                # Record is two or more windows old: nothing carries over.
                prev_count = 0
                curr_count = 0
            elif recorded == previous_window:
                # Record is one window behind: its current count becomes
                # the previous-window count.
                prev_count, curr_count = curr_count, 0

        # Estimate of requests in the trailing window: fade the previous
        # window's count as the current window progresses.
        weighted = prev_count * (1 - window_progress) + curr_count

        if weighted < self.limit:
            curr_count += 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            retry_after = self.window_size * (1 - window_progress)

        await self.backend.set(
            key,
            {
                "prev_count": prev_count,
                "curr_count": curr_count,
                "current_window": current_window,
            },
            ttl=self.window_size * 3,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.limit - weighted)),
            reset_at=current_window + self.window_size,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Drop both window counters for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Report the weighted usage for *key* without counting a request."""
        now = time.time()
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        window_progress = (now - current_window) / self.window_size

        stored = await self.backend.get(key)
        if stored is None:
            return None

        prev_count = int(stored.get("prev_count", 0))
        curr_count = int(stored.get("curr_count", 0))
        recorded = float(stored.get("current_window", 0))

        if recorded < previous_window:
            prev_count = 0
            curr_count = 0
        elif recorded == previous_window:
            prev_count, curr_count = curr_count, 0

        weighted = prev_count * (1 - window_progress) + curr_count

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.limit - weighted)),
            reset_at=current_window + self.window_size,
            window_size=self.window_size,
        )
|
||||
|
||||
|
||||
def get_algorithm(
    algorithm: Algorithm,
    limit: int,
    window_size: float,
    backend: Backend,
    *,
    burst_size: int | None = None,
) -> BaseAlgorithm:
    """Factory function to create algorithm instances.

    Raises:
        ValueError: If *algorithm* is not a known Algorithm member.
    """
    registry: dict[Algorithm, type[BaseAlgorithm]] = {
        Algorithm.TOKEN_BUCKET: TokenBucketAlgorithm,
        Algorithm.SLIDING_WINDOW: SlidingWindowAlgorithm,
        Algorithm.FIXED_WINDOW: FixedWindowAlgorithm,
        Algorithm.LEAKY_BUCKET: LeakyBucketAlgorithm,
        Algorithm.SLIDING_WINDOW_COUNTER: SlidingWindowCounterAlgorithm,
    }

    implementation = registry.get(algorithm)
    if implementation is None:
        raise ValueError(f"Unknown algorithm: {algorithm}")

    return implementation(limit, window_size, backend, burst_size=burst_size)
|
||||
81
fastapi_traffic/core/config.py
Normal file
81
fastapi_traffic/core/config.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Configuration for rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
|
||||
KeyExtractor = Callable[["Request"], str]
|
||||
|
||||
|
||||
def default_key_extractor(request: Request) -> str:
    """Extract client IP as the default rate limit key.

    Checks proxy headers first (X-Forwarded-For, then X-Real-IP),
    falling back to the transport-level client address, and finally
    the literal string "unknown".
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # First hop in the chain is the original client.
        first_hop, _, _ = forwarded_for.partition(",")
        return first_hop.strip()

    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip

    client = request.client
    return client.host if client else "unknown"
|
||||
|
||||
|
||||
@dataclass(slots=True)
class RateLimitConfig:
    """Configuration for a rate limit rule.

    Validated in __post_init__; limit, window_size and cost must all be
    strictly positive.
    """

    # Maximum number of requests allowed per window; must be > 0.
    limit: int
    # Window length in seconds; must be > 0.
    window_size: float = 60.0
    # Algorithm used to enforce the limit.
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    # Prefix included in keys written to the backend.
    key_prefix: str = "ratelimit"
    # Maps a Request to a client identifier (defaults to IP extraction).
    key_extractor: KeyExtractor = field(default=default_key_extractor)
    # Optional burst capacity for bucket-style algorithms.
    burst_size: int | None = None
    # Whether to attach rate limit headers to responses.
    include_headers: bool = True
    # Message used when the limit is exceeded.
    error_message: str = "Rate limit exceeded"
    # HTTP status code returned when the limit is exceeded.
    status_code: int = 429
    # If True, a backend failure lets the request through instead of raising.
    skip_on_error: bool = False
    # Number of slots one request consumes; must be > 0.
    cost: int = 1
    # Optional predicate; a True result exempts the request from limiting.
    exempt_when: Callable[[Request], bool] | None = None
    # Optional callback invoked when a request is blocked.
    on_blocked: Callable[[Request, Any], Any] | None = None

    def __post_init__(self) -> None:
        """Validate numeric fields.

        Raises:
            ValueError: If limit, window_size or cost is not positive.
        """
        if self.limit <= 0:
            msg = "limit must be positive"
            raise ValueError(msg)
        if self.window_size <= 0:
            msg = "window_size must be positive"
            raise ValueError(msg)
        if self.cost <= 0:
            msg = "cost must be positive"
            raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class GlobalConfig:
    """Global configuration for the rate limiter.

    Per-rule RateLimitConfig values take precedence; these are the
    process-wide defaults and global exemptions.
    """

    # Storage backend; when None a default is chosen by the limiter.
    backend: Backend | None = None
    # Master switch: when False, rate limiting is bypassed entirely.
    enabled: bool = True
    # Default requests-per-window when a rule does not specify one.
    default_limit: int = 100
    # Default window length in seconds.
    default_window_size: float = 60.0
    # Default enforcement algorithm.
    default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    # Global prefix prepended to every backend key.
    key_prefix: str = "fastapi_traffic"
    # Whether responses carry rate limit headers by default.
    include_headers: bool = True
    # Default message for blocked requests.
    error_message: str = "Rate limit exceeded. Please try again later."
    # Default HTTP status for blocked requests.
    status_code: int = 429
    # If True, backend errors fail open (request allowed).
    skip_on_error: bool = False
    # Client IPs that are never rate limited.
    exempt_ips: set[str] = field(default_factory=set)
    # URL paths that are never rate limited.
    exempt_paths: set[str] = field(default_factory=set)
    # Prefix used when emitting rate limit headers.
    headers_prefix: str = "X-RateLimit"
|
||||
259
fastapi_traffic/core/decorator.py
Normal file
259
fastapi_traffic/core/decorator.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""Rate limit decorator for FastAPI endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
|
||||
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import KeyExtractor, RateLimitConfig, default_key_extractor
|
||||
from fastapi_traffic.core.limiter import get_limiter
|
||||
from fastapi_traffic.exceptions import RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
def rate_limit(
    limit: int,
    *,
    window_size: float = ...,
    algorithm: Algorithm = ...,
    key_prefix: str = ...,
    key_extractor: KeyExtractor = ...,
    burst_size: int | None = ...,
    include_headers: bool = ...,
    error_message: str = ...,
    status_code: int = ...,
    skip_on_error: bool = ...,
    cost: int = ...,
    exempt_when: Callable[[Request], bool] | None = ...,
    on_blocked: Callable[[Request, Any], Any] | None = ...,
) -> Callable[[F], F]: ...


@overload
def rate_limit(
    limit: int,
    window_size: float,
    /,
) -> Callable[[F], F]: ...


def rate_limit(
    limit: int,
    window_size: float = 60.0,
    *,
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
    key_prefix: str = "ratelimit",
    key_extractor: KeyExtractor = default_key_extractor,
    burst_size: int | None = None,
    include_headers: bool = True,
    error_message: str = "Rate limit exceeded",
    status_code: int = 429,
    skip_on_error: bool = False,
    cost: int = 1,
    exempt_when: Callable[[Request], bool] | None = None,
    on_blocked: Callable[[Request, Any], Any] | None = None,
) -> Callable[[F], F]:
    """Decorator to apply rate limiting to a FastAPI endpoint.

    Args:
        limit: Maximum number of requests allowed in the window.
        window_size: Time window in seconds.
        algorithm: Rate limiting algorithm to use.
        key_prefix: Prefix for the rate limit key.
        key_extractor: Function to extract the client identifier from request.
        burst_size: Maximum burst size (for token bucket/leaky bucket).
        include_headers: Whether to include rate limit headers in response.
        error_message: Error message when rate limit is exceeded.
        status_code: HTTP status code when rate limit is exceeded.
        skip_on_error: Skip rate limiting if backend errors occur.
        cost: Cost of each request (default 1).
        exempt_when: Function to determine if request should be exempt.
        on_blocked: Callback when a request is blocked.

    Returns:
        Decorated function with rate limiting applied.

    Example:
        ```python
        from fastapi import FastAPI
        from fastapi_traffic import rate_limit

        app = FastAPI()

        @app.get("/api/resource")
        @rate_limit(100, 60)  # 100 requests per minute
        async def get_resource():
            return {"message": "Hello"}
        ```
    """
    config = RateLimitConfig(
        limit=limit,
        window_size=window_size,
        algorithm=algorithm,
        key_prefix=key_prefix,
        key_extractor=key_extractor,
        burst_size=burst_size,
        include_headers=include_headers,
        error_message=error_message,
        status_code=status_code,
        skip_on_error=skip_on_error,
        cost=cost,
        exempt_when=exempt_when,
        on_blocked=on_blocked,
    )

    def decorator(func: F) -> F:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            request = _extract_request(args, kwargs)
            if request is None:
                # Without a Request we cannot identify the client,
                # so the endpoint runs un-limited.
                return await func(*args, **kwargs)

            limiter = get_limiter()
            # Raises RateLimitExceeded when the limit is hit.
            result = await limiter.hit(request, config)

            response = await func(*args, **kwargs)

            # Only Response-like return values can carry headers.
            if config.include_headers and hasattr(response, "headers"):
                for key, value in result.info.to_headers().items():
                    response.headers[key] = value

            return response

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            import asyncio

            # FIX: asyncio.get_event_loop() is deprecated when no loop is
            # running (and run_until_complete fails inside a running loop);
            # asyncio.run creates and tears down a fresh event loop for
            # this single call.
            return asyncio.run(async_wrapper(*args, **kwargs))

        if _is_coroutine_function(func):
            return async_wrapper  # type: ignore[return-value]
        return sync_wrapper  # type: ignore[return-value]

    return decorator
|
||||
|
||||
|
||||
def _extract_request(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Request | None:
    """Extract the Request object from function arguments.

    Scans positional arguments first, then keyword values; returns the
    first starlette Request found, or None.
    """
    from starlette.requests import Request

    for arg in args:
        if isinstance(arg, Request):
            return arg

    # FIX: the former extra `kwargs["request"]` lookup was dead code —
    # scanning kwargs.values() already covers every keyword argument,
    # including "request".
    for value in kwargs.values():
        if isinstance(value, Request):
            return value

    return None
|
||||
|
||||
|
||||
def _is_coroutine_function(func: Callable[..., Any]) -> bool:
|
||||
"""Check if a function is a coroutine function."""
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
return asyncio.iscoroutinefunction(func) or inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
class RateLimitDependency:
    """FastAPI dependency that enforces a rate limit per call.

    Builds a RateLimitConfig once at construction time; each invocation
    checks (and consumes from) the limit via the default limiter and
    hands the resulting limit info to the endpoint.

    Example:
        ```python
        from fastapi import FastAPI, Depends
        from fastapi_traffic import RateLimitDependency

        app = FastAPI()
        rate_limiter = RateLimitDependency(limit=100, window_size=60)

        @app.get("/api/resource")
        async def get_resource(rate_limit_info = Depends(rate_limiter)):
            return {"remaining": rate_limit_info.remaining}
        ```
    """

    __slots__ = ("_config",)

    def __init__(
        self,
        limit: int,
        window_size: float = 60.0,
        *,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        key_prefix: str = "ratelimit",
        key_extractor: KeyExtractor = default_key_extractor,
        burst_size: int | None = None,
        error_message: str = "Rate limit exceeded",
        status_code: int = 429,
        skip_on_error: bool = False,
        cost: int = 1,
        exempt_when: Callable[[Request], bool] | None = None,
    ) -> None:
        """Freeze the per-endpoint configuration for later calls."""
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            burst_size=burst_size,
            include_headers=True,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
            cost=cost,
            exempt_when=exempt_when,
        )

    async def __call__(self, request: Request) -> Any:
        """Check rate limit and return info.

        Raises RateLimitExceeded (via the limiter's hit) when blocked.
        """
        outcome = await get_limiter().hit(request, self._config)
        return outcome.info
|
||||
|
||||
|
||||
def create_rate_limit_response(
    exc: RateLimitExceeded,
    *,
    include_headers: bool = True,
) -> Response:
    """Create a rate limit exceeded response.

    Args:
        exc: The RateLimitExceeded exception.
        include_headers: Whether to include rate limit headers.

    Returns:
        A JSONResponse with rate limit information.
    """
    from starlette.responses import JSONResponse

    limit_headers: dict[str, str] = (
        exc.limit_info.to_headers()
        if include_headers and exc.limit_info is not None
        else {}
    )

    body = {
        "detail": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(status_code=429, content=body, headers=limit_headers)
|
||||
301
fastapi_traffic/core/limiter.py
Normal file
301
fastapi_traffic/core/limiter.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Core rate limiter implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm
|
||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
||||
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
|
||||
from fastapi_traffic.exceptions import BackendError, RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
    """Main rate limiter class that manages rate limiting logic.

    Owns a storage backend and a cache of algorithm instances keyed by
    (algorithm, limit, window, burst). Lazily initialized on first check.
    """

    __slots__ = ("_config", "_backend", "_algorithms", "_initialized")

    def __init__(
        self,
        backend: Backend | None = None,
        *,
        config: GlobalConfig | None = None,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            backend: Storage backend for rate limit data.
            config: Global configuration options.
        """
        self._config = config or GlobalConfig()
        # Precedence: explicit backend arg > config.backend > in-memory default.
        self._backend = backend or self._config.backend or MemoryBackend()
        self._algorithms: dict[str, BaseAlgorithm] = {}
        self._initialized = False

    @property
    def backend(self) -> Backend:
        """Get the storage backend."""
        return self._backend

    @property
    def config(self) -> GlobalConfig:
        """Get the global configuration."""
        return self._config

    async def initialize(self) -> None:
        """Initialize the rate limiter and backend.

        Idempotent; backend hooks are optional, so both calls are
        duck-typed via hasattr.
        """
        if self._initialized:
            return

        if hasattr(self._backend, "initialize"):
            await self._backend.initialize()  # type: ignore[union-attr]

        if hasattr(self._backend, "start_cleanup"):
            await self._backend.start_cleanup()  # type: ignore[union-attr]

        self._initialized = True

    async def close(self) -> None:
        """Close the rate limiter and cleanup resources."""
        await self._backend.close()
        self._algorithms.clear()
        self._initialized = False

    def _get_algorithm(
        self,
        limit: int,
        window_size: float,
        algorithm: Algorithm,
        burst_size: int | None = None,
    ) -> BaseAlgorithm:
        """Get or create an algorithm instance.

        Instances are cached per unique parameter combination so repeated
        requests with the same config reuse one object.
        """
        cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}"
        if cache_key not in self._algorithms:
            self._algorithms[cache_key] = get_algorithm(
                algorithm,
                limit,
                window_size,
                self._backend,
                burst_size=burst_size,
            )
        return self._algorithms[cache_key]

    def _build_key(
        self,
        request: Request,
        config: RateLimitConfig,
        identifier: str | None = None,
    ) -> str:
        """Build the rate limit key for a request.

        Key shape: <global prefix>:<rule prefix>:<method>:<path>:<client>.
        An explicit identifier overrides the configured key extractor.
        """
        if identifier:
            client_id = identifier
        else:
            client_id = config.key_extractor(request)

        path = request.url.path
        method = request.method

        return f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}"

    def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool:
        """Check if the request is exempt from rate limiting.

        Exempt when: limiting is globally disabled, the rule's
        exempt_when predicate fires, the extracted client id is in the
        global IP allowlist, or the path is globally exempt.
        """
        if not self._config.enabled:
            return True

        if config.exempt_when is not None and config.exempt_when(request):
            return True

        client_ip = config.key_extractor(request)
        if client_ip in self._config.exempt_ips:
            return True

        if request.url.path in self._config.exempt_paths:
            return True

        return False

    async def check(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check if a request is allowed under the rate limit.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult with allowed status and limit info.
        """
        if not self._initialized:
            await self.initialize()

        if self._is_exempt(request, config):
            # Synthetic "full quota" result; reset_at=0 marks no real window.
            return RateLimitResult(
                allowed=True,
                info=RateLimitInfo(
                    limit=config.limit,
                    remaining=config.limit,
                    reset_at=0,
                    window_size=config.window_size,
                ),
                key="exempt",
            )

        key = self._build_key(request, config, identifier)
        # NOTE(review): `or` treats an explicit cost=0 override as "use
        # config.cost"; consider `cost if cost is not None else config.cost`.
        actual_cost = cost or config.cost

        try:
            algorithm = self._get_algorithm(
                config.limit,
                config.window_size,
                config.algorithm,
                config.burst_size,
            )

            info: RateLimitInfo | None = None
            # Multi-cost requests consume one slot per iteration.
            # NOTE(review): slots consumed before a mid-loop denial are
            # not refunded when the request is ultimately blocked.
            for _ in range(actual_cost):
                allowed, info = await algorithm.check(key)
                if not allowed:
                    return RateLimitResult(allowed=False, info=info, key=key)

            if info is None:
                # Only reachable if actual_cost iterated zero times.
                info = RateLimitInfo(
                    limit=config.limit,
                    remaining=config.limit,
                    reset_at=0,
                    window_size=config.window_size,
                )
            return RateLimitResult(allowed=True, info=info, key=key)

        except BackendError as e:
            logger.warning("Backend error during rate limit check: %s", e)
            if config.skip_on_error:
                # Fail open: allow the request with a synthetic full quota.
                return RateLimitResult(
                    allowed=True,
                    info=RateLimitInfo(
                        limit=config.limit,
                        remaining=config.limit,
                        reset_at=0,
                        window_size=config.window_size,
                    ),
                    key=key,
                )
            raise

    async def hit(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check rate limit and raise exception if exceeded.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult if allowed.

        Raises:
            RateLimitExceeded: If the rate limit is exceeded.
        """
        result = await self.check(request, config, identifier=identifier, cost=cost)

        if not result.allowed:
            # Callback fires before the exception propagates.
            if config.on_blocked is not None:
                config.on_blocked(request, result)

            raise RateLimitExceeded(
                config.error_message,
                retry_after=result.info.retry_after,
                limit_info=result.info,
            )

        return result

    async def reset(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> None:
        """Reset the rate limit for a specific key.

        Args:
            request: The request to reset limits for.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.
        """
        key = self._build_key(request, config, identifier)
        algorithm = self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )
        await algorithm.reset(key)

    async def get_state(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> RateLimitInfo | None:
        """Get the current rate limit state without consuming a token.

        Args:
            request: The request to check.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.

        Returns:
            Current rate limit info or None if no state exists.
        """
        key = self._build_key(request, config, identifier)
        algorithm = self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )
        return await algorithm.get_state(key)
|
||||
|
||||
|
||||
# Module-level singleton used by the decorator and dependency helpers.
_default_limiter: RateLimiter | None = None


def get_limiter() -> RateLimiter:
    """Return the process-wide default limiter, creating it lazily."""
    global _default_limiter
    if _default_limiter is None:
        _default_limiter = RateLimiter()
    return _default_limiter
|
||||
|
||||
|
||||
def set_limiter(limiter: RateLimiter) -> None:
    """Install *limiter* as the process-wide default instance."""
    global _default_limiter
    _default_limiter = limiter
|
||||
89
fastapi_traffic/core/models.py
Normal file
89
fastapi_traffic/core/models.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Data models for rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class KeyType(str, Enum):
    """Type of key extraction for rate limiting."""

    IP = "ip"            # client IP address
    USER = "user"        # authenticated user identity
    API_KEY = "api_key"  # API key credential
    ENDPOINT = "endpoint"  # route/endpoint identity
    CUSTOM = "custom"    # caller-supplied extractor
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitInfo:
|
||||
"""Information about the current rate limit state."""
|
||||
|
||||
limit: int
|
||||
remaining: int
|
||||
reset_at: float
|
||||
retry_after: float | None = None
|
||||
window_size: float = 60.0
|
||||
|
||||
def to_headers(self) -> dict[str, str]:
|
||||
"""Convert rate limit info to HTTP headers."""
|
||||
headers: dict[str, str] = {
|
||||
"X-RateLimit-Limit": str(self.limit),
|
||||
"X-RateLimit-Remaining": str(max(0, self.remaining)),
|
||||
"X-RateLimit-Reset": str(int(self.reset_at)),
|
||||
}
|
||||
if self.retry_after is not None:
|
||||
headers["Retry-After"] = str(int(self.retry_after))
|
||||
return headers
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class RateLimitResult:
    """Result of a rate limit check."""

    # True when the request may proceed.
    allowed: bool
    # Snapshot of the limit state at check time.
    info: RateLimitInfo
    # Backend key the check was performed against.
    key: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TokenBucketState:
|
||||
"""State for token bucket algorithm."""
|
||||
|
||||
tokens: float
|
||||
last_update: float
|
||||
|
||||
|
||||
@dataclass(slots=True)
class SlidingWindowState:
    """State for sliding window algorithm."""

    # Timestamps of requests observed within the window (sliding-window log).
    timestamps: list[float] = field(default_factory=list)
    # Request count; NOTE(review): its relationship to len(timestamps) is
    # not visible here — possibly used by the counter variant. Confirm.
    count: int = 0
|
||||
|
||||
|
||||
@dataclass(slots=True)
class FixedWindowState:
    """State for fixed window algorithm."""

    # Number of requests counted in the current window.
    count: int
    # Timestamp marking the start of the current window.
    window_start: float
|
||||
|
||||
|
||||
@dataclass(slots=True)
class LeakyBucketState:
    """State for leaky bucket algorithm."""

    # Current fill level of the bucket; presumably drained ("leaked") over
    # time by the algorithm — confirm in the algorithm implementation.
    water_level: float
    # Timestamp of the last state update.
    last_update: float
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class BackendRecord:
    """Generic record stored in backends."""

    # Storage key the record is filed under.
    key: str
    # Arbitrary algorithm-specific payload.
    data: dict[str, Any]
    # Unix timestamp after which the record is presumably eligible for
    # expiry/cleanup — confirm against backend implementations.
    expires_at: float
|
||||
50
fastapi_traffic/exceptions.py
Normal file
50
fastapi_traffic/exceptions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Custom exceptions for FastAPI Traffic rate limiter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_traffic.core.models import RateLimitInfo
|
||||
|
||||
|
||||
class FastAPITrafficError(Exception):
    """Root of the exception hierarchy for FastAPI Traffic.

    Catching this type catches every error raised by the library.
    """
|
||||
|
||||
|
||||
class RateLimitExceeded(FastAPITrafficError):
    """Raised when a rate limit has been exceeded."""

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        *,
        retry_after: float | None = None,
        limit_info: RateLimitInfo | None = None,
    ) -> None:
        """Create the error.

        Args:
            message: Human-readable description of the failure.
            retry_after: Seconds until the caller may retry, if known.
            limit_info: Snapshot of the limit state, if available.
        """
        super().__init__(message)
        # Keep everything an HTTP layer needs to build a 429 response.
        self.limit_info = limit_info
        self.retry_after = retry_after
        self.message = message
|
||||
|
||||
|
||||
class BackendError(FastAPITrafficError):
    """Raised when a backend operation fails."""

    def __init__(
        self,
        message: str = "Backend operation failed",
        *,
        original_error: Exception | None = None,
    ) -> None:
        """Create the error.

        Args:
            message: Description of the failed operation.
            original_error: Underlying exception, kept as an attribute
                for diagnostics.
        """
        super().__init__(message)
        self.original_error = original_error
        self.message = message
|
||||
|
||||
|
||||
class ConfigurationError(FastAPITrafficError):
    """Raised when there is a configuration error."""
|
||||
184
fastapi_traffic/middleware.py
Normal file
184
fastapi_traffic/middleware.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Rate limiting middleware for Starlette/FastAPI applications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.core.algorithms import Algorithm
|
||||
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig, default_key_extractor
|
||||
from fastapi_traffic.core.limiter import RateLimiter
|
||||
from fastapi_traffic.exceptions import RateLimitExceeded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for global rate limiting across all endpoints.

    Wraps every request in a limiter check; rejected requests receive a
    JSON error response with the configured status code.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        backend: Backend | None = None,
        key_prefix: str = "middleware",
        include_headers: bool = True,
        error_message: str = "Rate limit exceeded. Please try again later.",
        status_code: int = 429,
        skip_on_error: bool = False,
        exempt_paths: set[str] | None = None,
        exempt_ips: set[str] | None = None,
        key_extractor: Callable[[Request], str] = default_key_extractor,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The ASGI application.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            algorithm: Rate limiting algorithm.
            backend: Storage backend (defaults to MemoryBackend).
            key_prefix: Prefix for rate limit keys.
            include_headers: Include rate limit headers in response.
            error_message: Error message when rate limited.
            status_code: HTTP status code when rate limited.
            skip_on_error: Skip rate limiting on backend errors.
            exempt_paths: Paths to exempt from rate limiting.
            exempt_ips: IP addresses to exempt from rate limiting.
            key_extractor: Function to extract client identifier.
        """
        super().__init__(app)

        # Fall back to an in-process store when no backend is supplied.
        self._backend = backend or MemoryBackend()
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            include_headers=include_headers,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
        )

        # Exemptions are limiter-wide, so they live in the GlobalConfig
        # rather than the per-rule RateLimitConfig above.
        global_config = GlobalConfig(
            backend=self._backend,
            exempt_paths=exempt_paths or set(),
            exempt_ips=exempt_ips or set(),
        )

        # The middleware owns a dedicated limiter instance; it does not
        # share the module-level default limiter.
        self._limiter = RateLimiter(self._backend, config=global_config)
        self._include_headers = include_headers
        self._error_message = error_message
        self._status_code = status_code

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Process the request with rate limiting.

        Denied requests short-circuit with a JSON error response; allowed
        requests proceed downstream and (optionally) have rate limit
        headers attached to the response.
        """
        try:
            result = await self._limiter.check(request, self._config)

            # Denied: short-circuit before calling the downstream app.
            if not result.allowed:
                return self._create_rate_limit_response(result)

            response = await call_next(request)

            # Allowed: annotate the response with the current limit state
            # so clients can self-throttle.
            if self._include_headers:
                for key, value in result.info.to_headers().items():
                    response.headers[key] = value

            return response

        except RateLimitExceeded as exc:
            # The limiter may raise instead of returning a denied result.
            # NOTE(review): headers are attached here even when
            # include_headers is False — confirm this asymmetry is intended.
            return JSONResponse(
                status_code=self._status_code,
                content={
                    "detail": exc.message,
                    "retry_after": exc.retry_after,
                },
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        except Exception as e:
            logger.exception("Error in rate limit middleware: %s", e)
            # Fail open only when explicitly configured to do so;
            # otherwise surface the error to the ASGI server.
            if self._config.skip_on_error:
                return await call_next(request)
            raise

    def _create_rate_limit_response(self, result: object) -> JSONResponse:
        """Create a rate limit exceeded response."""
        # Imported lazily — presumably to avoid a module-level import
        # cycle with fastapi_traffic.core.models; confirm.
        from fastapi_traffic.core.models import RateLimitResult

        # Defensive check: `result` is typed as object, so fall back to a
        # header-less response for unexpected types.
        if isinstance(result, RateLimitResult):
            headers = result.info.to_headers()
            retry_after = result.info.retry_after
        else:
            headers = {}
            retry_after = None

        return JSONResponse(
            status_code=self._status_code,
            content={
                "detail": self._error_message,
                "retry_after": retry_after,
            },
            headers=headers,
        )
|
||||
|
||||
|
||||
class SlidingWindowMiddleware(RateLimitMiddleware):
    """Convenience middleware using sliding window algorithm."""

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Delegate to ``RateLimitMiddleware`` with the algorithm pinned
        to ``Algorithm.SLIDING_WINDOW``; all other options pass through.
        """
        # Only the algorithm is fixed here; everything else is forwarded.
        super().__init__(
            app,
            algorithm=Algorithm.SLIDING_WINDOW,
            limit=limit,
            window_size=window_size,
            **kwargs,  # type: ignore[arg-type]
        )
|
||||
|
||||
|
||||
class TokenBucketMiddleware(RateLimitMiddleware):
    """Convenience middleware using token bucket algorithm."""

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Delegate to ``RateLimitMiddleware`` with the algorithm pinned
        to ``Algorithm.TOKEN_BUCKET``; all other options pass through.
        """
        # Only the algorithm is fixed here; everything else is forwarded.
        super().__init__(
            app,
            algorithm=Algorithm.TOKEN_BUCKET,
            limit=limit,
            window_size=window_size,
            **kwargs,  # type: ignore[arg-type]
        )
|
||||
0
fastapi_traffic/py.typed
Normal file
0
fastapi_traffic/py.typed
Normal file
Reference in New Issue
Block a user