Initial commit: fastapi-traffic rate limiting library

- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.)
- SQLite and memory backends
- Decorator and dependency injection patterns
- Middleware support
- Example usage files
This commit is contained in:
2026-01-09 00:26:19 +00:00
commit da496746bb
38 changed files with 5790 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
"""Backend implementations for rate limit storage."""
from fastapi_traffic.backends.base import Backend
from fastapi_traffic.backends.memory import MemoryBackend
from fastapi_traffic.backends.sqlite import SQLiteBackend

__all__ = ["Backend", "MemoryBackend", "SQLiteBackend"]

# The Redis backend depends on the optional ``redis`` package; it is
# exported only when that dependency is installed.
try:
    from fastapi_traffic.backends.redis import RedisBackend
except ImportError:
    pass
else:
    __all__.append("RedisBackend")

View File

@@ -0,0 +1,89 @@
"""Abstract base class for rate limit storage backends."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class Backend(ABC):
    """Contract that every rate limit storage backend must satisfy.

    Concrete backends persist per-key state dictionaries and honour a
    TTL so stale entries eventually disappear.  Every operation is a
    coroutine.  Backends may also be used as async context managers;
    leaving the context closes the backend.
    """

    @abstractmethod
    async def get(self, key: str) -> dict[str, Any] | None:
        """Return the stored state dictionary for *key*, or ``None`` when absent."""
        ...

    @abstractmethod
    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Store *value* under *key*, expiring after *ttl* seconds."""
        ...

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Remove any state stored under *key*."""
        ...

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Return ``True`` when *key* currently has stored state."""
        ...

    @abstractmethod
    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically add *amount* to the counter at *key* and return the new total."""
        ...

    @abstractmethod
    async def clear(self) -> None:
        """Drop every stored rate limit entry."""
        ...

    async def close(self) -> None:
        """Release backend resources; the default implementation does nothing."""

    async def __aenter__(self) -> Backend:
        """Enter the async context, yielding this backend."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Leave the async context by closing the backend."""
        await self.close()

View File

@@ -0,0 +1,139 @@
"""In-memory backend for rate limiting - suitable for single-process applications."""
from __future__ import annotations
import asyncio
import time
from collections import OrderedDict
from typing import Any
from fastapi_traffic.backends.base import Backend
class MemoryBackend(Backend):
    """Thread-safe in-memory backend with LRU eviction and TTL support."""

    __slots__ = ("_data", "_lock", "_max_size", "_cleanup_interval", "_cleanup_task")

    def __init__(
        self,
        *,
        max_size: int = 10000,
        cleanup_interval: float = 60.0,
    ) -> None:
        """Initialize the memory backend.

        Args:
            max_size: Maximum number of entries to store (LRU eviction).
            cleanup_interval: Interval in seconds for cleaning expired entries.
        """
        # Maps key -> (state dict, absolute expiry timestamp); insertion
        # order of the OrderedDict doubles as the LRU order.
        self._data: OrderedDict[str, tuple[dict[str, Any], float]] = OrderedDict()
        self._lock = asyncio.Lock()
        self._max_size = max_size
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start_cleanup(self) -> None:
        """Start the background cleanup task (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Periodically purge expired entries until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Best-effort maintenance: a failed sweep must not kill the loop.
                pass

    async def _cleanup_expired(self) -> None:
        """Drop every entry whose TTL has elapsed."""
        cutoff = time.time()
        async with self._lock:
            stale = [k for k, (_, deadline) in self._data.items() if deadline <= cutoff]
            for k in stale:
                del self._data[k]

    def _evict_if_needed(self) -> None:
        """Pop least-recently-used entries while over capacity (lock must be held)."""
        while len(self._data) > self._max_size:
            self._data.popitem(last=False)

    async def get(self, key: str) -> dict[str, Any] | None:
        """Return a copy of the state for *key*, or None when absent or expired."""
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return None
            state, deadline = entry
            if deadline <= time.time():
                # Expire lazily on read.
                del self._data[key]
                return None
            self._data.move_to_end(key)
            return state.copy()

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Store a copy of *value* under *key*, expiring after *ttl* seconds."""
        deadline = time.time() + ttl
        async with self._lock:
            self._data[key] = (value.copy(), deadline)
            self._data.move_to_end(key)
            self._evict_if_needed()

    async def delete(self, key: str) -> None:
        """Remove *key*; a missing key is silently ignored."""
        async with self._lock:
            self._data.pop(key, None)

    async def exists(self, key: str) -> bool:
        """Report whether *key* holds a live (non-expired) entry."""
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return False
            if entry[1] <= time.time():
                del self._data[key]
                return False
            return True

    async def increment(self, key: str, amount: int = 1) -> int:
        """Add *amount* to the "count" field of an existing live entry.

        For a missing or expired key the method returns *amount* without
        creating an entry (the caller is expected to ``set`` first so the
        TTL is known).
        """
        async with self._lock:
            entry = self._data.get(key)
            if entry is not None:
                state, deadline = entry
                if deadline > time.time():
                    updated = int(state.get("count", 0)) + amount
                    state["count"] = updated
                    self._data[key] = (state, deadline)
                    return updated
            return amount

    async def clear(self) -> None:
        """Discard every stored entry."""
        async with self._lock:
            self._data.clear()

    async def close(self) -> None:
        """Cancel the cleanup task (if running) and drop all data."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        await self.clear()

    def __len__(self) -> int:
        """Number of stored entries (expired-but-unswept entries included)."""
        return len(self._data)

View File

@@ -0,0 +1,232 @@
"""Redis backend for rate limiting - distributed storage for multi-node deployments."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from fastapi_traffic.backends.base import Backend
from fastapi_traffic.exceptions import BackendError
if TYPE_CHECKING:
from redis.asyncio import Redis
class RedisBackend(Backend):
    """Redis-based backend for distributed rate limiting.

    State dictionaries are stored as JSON strings under prefixed keys.
    Every Redis failure is wrapped in ``BackendError`` with the original
    exception chained as ``__cause__``.
    """

    __slots__ = ("_client", "_key_prefix", "_owns_client")

    def __init__(
        self,
        client: Redis[bytes],
        *,
        key_prefix: str = "fastapi_traffic",
    ) -> None:
        """Initialize the Redis backend.

        Args:
            client: An async Redis client instance.  The caller retains
                ownership; ``close()`` will not disconnect it.
            key_prefix: Prefix for all rate limit keys.
        """
        self._client = client
        self._key_prefix = key_prefix
        # Only connections created via from_url() are closed by close().
        self._owns_client = False

    @classmethod
    async def from_url(
        cls,
        url: str = "redis://localhost:6379/0",
        *,
        key_prefix: str = "fastapi_traffic",
        **kwargs: Any,
    ) -> RedisBackend:
        """Create a RedisBackend that owns its own connection.

        Args:
            url: Redis connection URL.
            key_prefix: Prefix for all rate limit keys.
            **kwargs: Additional arguments passed to Redis.from_url().

        Returns:
            A new RedisBackend instance whose connection is closed by
            ``close()``.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        try:
            from redis.asyncio import Redis
        except ImportError as e:
            msg = "redis package is required for RedisBackend. Install with: pip install redis"
            raise ImportError(msg) from e
        client: Redis[bytes] = Redis.from_url(url, **kwargs)
        instance = cls(client, key_prefix=key_prefix)
        instance._owns_client = True
        return instance

    def _make_key(self, key: str) -> str:
        """Return *key* namespaced under the configured prefix."""
        return f"{self._key_prefix}:{key}"

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key.

        Raises:
            BackendError: If the Redis call or JSON decoding fails.
        """
        try:
            full_key = self._make_key(key)
            data = await self._client.get(full_key)
            if data is None:
                return None
            result: dict[str, Any] = json.loads(data)
            return result
        except Exception as e:
            # Chain the original exception so tracebacks show the root cause.
            raise BackendError(f"Failed to get key {key}", original_error=e) from e

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        The TTL is rounded up by one second because SETEX only accepts
        whole seconds and truncation could expire entries early.
        """
        try:
            full_key = self._make_key(key)
            data = json.dumps(value)
            await self._client.setex(full_key, int(ttl) + 1, data)
        except Exception as e:
            raise BackendError(f"Failed to set key {key}", original_error=e) from e

    async def delete(self, key: str) -> None:
        """Delete the state for a key."""
        try:
            full_key = self._make_key(key)
            await self._client.delete(full_key)
        except Exception as e:
            raise BackendError(f"Failed to delete key {key}", original_error=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists."""
        try:
            full_key = self._make_key(key)
            result = await self._client.exists(full_key)
            return bool(result)
        except Exception as e:
            raise BackendError(f"Failed to check key {key}", original_error=e) from e

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment a counter using Redis INCRBY.

        NOTE(review): INCRBY creates a missing key *without* a TTL and
        operates on plain integer strings, not the JSON documents written
        by ``set()`` — counters must therefore live under their own keys
        or they will never expire. Confirm this matches the limiter's
        usage.
        """
        try:
            full_key = self._make_key(key)
            result = await self._client.incrby(full_key, amount)
            return int(result)
        except Exception as e:
            raise BackendError(f"Failed to increment key {key}", original_error=e) from e

    async def clear(self) -> None:
        """Clear all rate limit data with this prefix.

        Uses incremental SCAN rather than KEYS so a large keyspace does
        not block the Redis server.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                if keys:
                    await self._client.delete(*keys)
                if cursor == 0:
                    break
        except Exception as e:
            raise BackendError("Failed to clear rate limits", original_error=e) from e

    async def close(self) -> None:
        """Close the Redis connection if this backend created it."""
        if self._owns_client:
            await self._client.aclose()

    async def ping(self) -> bool:
        """Return True when Redis answers a PING, False on any error."""
        try:
            await self._client.ping()
            return True
        except Exception:
            return False

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with the number of prefixed keys, Redis memory usage,
            and the configured key prefix.

        Raises:
            BackendError: If any Redis call fails.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            count = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                count += len(keys)
                if cursor == 0:
                    break
            info = await self._client.info("memory")
            return {
                "total_keys": count,
                "used_memory": info.get("used_memory_human", "unknown"),
                "key_prefix": self._key_prefix,
            }
        except Exception as e:
            raise BackendError("Failed to get stats", original_error=e) from e
# Lua scripts for atomic operations
SLIDING_WINDOW_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local window_start = now - window_size
-- Remove expired entries
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
-- Count current entries
local count = redis.call('ZCARD', key)
if count < limit then
-- Add new entry
redis.call('ZADD', key, now, now .. ':' .. math.random())
redis.call('EXPIRE', key, math.ceil(window_size) + 1)
return {1, limit - count - 1}
else
-- Get oldest entry for retry-after calculation
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local retry_after = 0
if #oldest > 0 then
retry_after = oldest[2] + window_size - now
end
return {0, 0, retry_after}
end
"""
TOKEN_BUCKET_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local bucket_size = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
local data = redis.call('GET', key)
local tokens, last_update
if data then
local decoded = cjson.decode(data)
tokens = decoded.tokens
last_update = decoded.last_update
else
tokens = bucket_size
last_update = now
end
-- Refill tokens
local elapsed = now - last_update
tokens = math.min(bucket_size, tokens + elapsed * refill_rate)
local allowed = 0
local retry_after = 0
if tokens >= 1 then
tokens = tokens - 1
allowed = 1
else
retry_after = (1 - tokens) / refill_rate
end
-- Save state
redis.call('SETEX', key, ttl, cjson.encode({tokens = tokens, last_update = now}))
return {allowed, math.floor(tokens), retry_after}
"""

View File

@@ -0,0 +1,298 @@
"""SQLite backend for rate limiting - persistent storage for single-node deployments."""
from __future__ import annotations
import asyncio
import json
import sqlite3
import time
from pathlib import Path
from typing import Any
from fastapi_traffic.backends.base import Backend
from fastapi_traffic.exceptions import BackendError
class SQLiteBackend(Backend):
    """SQLite-based backend with WAL mode and async execution.

    Blocking sqlite3 calls are pushed onto the default thread-pool
    executor so the event loop stays responsive.  A single autocommit
    connection is shared; ``increment()`` additionally serializes through
    an asyncio lock because it performs a read-modify-write cycle.
    """

    __slots__ = (
        "_db_path",
        "_connection",
        "_lock",
        "_cleanup_interval",
        "_cleanup_task",
        "_pool_size",
        "_connections",
    )

    def __init__(
        self,
        db_path: str | Path = ":memory:",
        *,
        cleanup_interval: float = 300.0,
        pool_size: int = 5,
    ) -> None:
        """Initialize the SQLite backend.

        Args:
            db_path: Path to SQLite database file or ":memory:" for in-memory.
            cleanup_interval: Interval in seconds for cleaning expired entries.
            pool_size: Number of connections in the pool.
                NOTE(review): the pool (``_connections``/``_pool_size``) is
                never populated — every operation runs on the single shared
                connection.  Confirm whether pooling is planned or the
                parameter should be deprecated.
        """
        self._db_path = str(db_path)
        self._connection: sqlite3.Connection | None = None
        self._lock = asyncio.Lock()
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None
        self._pool_size = pool_size
        self._connections: list[sqlite3.Connection] = []

    async def initialize(self) -> None:
        """Open the database, create tables, and start the cleanup task."""
        await self._ensure_connection()
        await self._create_tables()
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _ensure_connection(self) -> sqlite3.Connection:
        """Lazily open and cache the shared database connection."""
        if self._connection is None:
            # get_running_loop() is the correct call from inside a
            # coroutine; get_event_loop() is deprecated for this use
            # since Python 3.10.
            loop = asyncio.get_running_loop()
            self._connection = await loop.run_in_executor(
                None, self._create_connection
            )
        assert self._connection is not None
        return self._connection

    def _create_connection(self) -> sqlite3.Connection:
        """Create a new SQLite connection with optimized settings."""
        conn = sqlite3.connect(
            self._db_path,
            check_same_thread=False,  # executor threads share this connection
            isolation_level=None,  # autocommit: each statement commits itself
        )
        # WAL permits concurrent readers during writes; NORMAL sync is a
        # safe durability/speed trade-off under WAL.
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.execute("PRAGMA cache_size=10000")
        conn.execute("PRAGMA temp_store=MEMORY")
        conn.row_factory = sqlite3.Row
        return conn

    async def _create_tables(self) -> None:
        """Create the rate limit tables."""
        conn = await self._ensure_connection()
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._create_tables_sync, conn)

    def _create_tables_sync(self, conn: sqlite3.Connection) -> None:
        """Synchronously create the storage table and its TTL index."""
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_limits (
                key TEXT PRIMARY KEY,
                data TEXT NOT NULL,
                expires_at REAL NOT NULL
            )
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at)
        """)

    async def _cleanup_loop(self) -> None:
        """Background task that periodically purges expired rows."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Best-effort maintenance: a failed sweep must not kill the loop.
                pass

    async def _cleanup_expired(self) -> None:
        """Remove expired entries.

        Raises:
            BackendError: If the DELETE statement fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: conn.execute(
                    "DELETE FROM rate_limits WHERE expires_at <= ?", (time.time(),)
                ),
            )
        except Exception as e:
            # Chain the original exception so tracebacks show the root cause.
            raise BackendError("Failed to cleanup expired entries", original_error=e) from e

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key.

        Expired rows are deleted lazily on read and treated as missing.

        Raises:
            BackendError: If the query or JSON decoding fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _get() -> dict[str, Any] | None:
                cursor = conn.execute(
                    "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                    (key,),
                )
                row = cursor.fetchone()
                if row is None:
                    return None
                expires_at = row["expires_at"]
                if expires_at <= time.time():
                    conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,))
                    return None
                data: dict[str, Any] = json.loads(row["data"])
                return data

            return await loop.run_in_executor(None, _get)
        except Exception as e:
            raise BackendError(f"Failed to get key {key}", original_error=e) from e

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        Raises:
            BackendError: If the upsert fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            expires_at = time.time() + ttl
            data_json = json.dumps(value)

            def _set() -> None:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO rate_limits (key, data, expires_at)
                    VALUES (?, ?, ?)
                    """,
                    (key, data_json, expires_at),
                )

            await loop.run_in_executor(None, _set)
        except Exception as e:
            raise BackendError(f"Failed to set key {key}", original_error=e) from e

    async def delete(self, key: str) -> None:
        """Delete the state for a key.

        Raises:
            BackendError: If the DELETE statement fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,)),
            )
        except Exception as e:
            raise BackendError(f"Failed to delete key {key}", original_error=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired.

        Raises:
            BackendError: If the query fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _exists() -> bool:
                cursor = conn.execute(
                    "SELECT 1 FROM rate_limits WHERE key = ? AND expires_at > ?",
                    (key, time.time()),
                )
                return cursor.fetchone() is not None

            return await loop.run_in_executor(None, _exists)
        except Exception as e:
            raise BackendError(f"Failed to check key {key}", original_error=e) from e

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment a counter.

        For a missing or expired key the method returns *amount* without
        creating a row (the caller is expected to ``set`` first so the
        TTL is known).  The asyncio lock serializes the read-modify-write
        cycle within this process.

        Raises:
            BackendError: If the read or update fails.
        """
        async with self._lock:
            try:
                conn = await self._ensure_connection()
                loop = asyncio.get_running_loop()

                def _increment() -> int:
                    cursor = conn.execute(
                        "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                        (key,),
                    )
                    row = cursor.fetchone()
                    if row is None or row["expires_at"] <= time.time():
                        return amount
                    data: dict[str, Any] = json.loads(row["data"])
                    current = int(data.get("count", 0))
                    new_value = current + amount
                    data["count"] = new_value
                    conn.execute(
                        "UPDATE rate_limits SET data = ? WHERE key = ?",
                        (json.dumps(data), key),
                    )
                    return new_value

                return await loop.run_in_executor(None, _increment)
            except Exception as e:
                raise BackendError(f"Failed to increment key {key}", original_error=e) from e

    async def clear(self) -> None:
        """Clear all rate limit data.

        Raises:
            BackendError: If the DELETE statement fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, lambda: conn.execute("DELETE FROM rate_limits")
            )
        except Exception as e:
            raise BackendError("Failed to clear rate limits", original_error=e) from e

    async def close(self) -> None:
        """Cancel the cleanup task and close all connections."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        if self._connection is not None:
            self._connection.close()
            self._connection = None
        # Defensive: drain the (currently unused) pool list as well.
        for conn in self._connections:
            conn.close()
        self._connections.clear()

    async def vacuum(self) -> None:
        """Optimize the database by running VACUUM.

        Raises:
            BackendError: If the VACUUM statement fails.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, lambda: conn.execute("VACUUM"))
        except Exception as e:
            raise BackendError("Failed to vacuum database", original_error=e) from e

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with total, active, and expired entry counts plus the
            database path.

        Raises:
            BackendError: If the queries fail.
        """
        try:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()

            def _stats() -> dict[str, Any]:
                cursor = conn.execute("SELECT COUNT(*) as total FROM rate_limits")
                total = cursor.fetchone()["total"]
                cursor = conn.execute(
                    "SELECT COUNT(*) as active FROM rate_limits WHERE expires_at > ?",
                    (time.time(),),
                )
                active = cursor.fetchone()["active"]
                return {
                    "total_entries": total,
                    "active_entries": active,
                    "expired_entries": total - active,
                    "db_path": self._db_path,
                }

            return await loop.run_in_executor(None, _stats)
        except Exception as e:
            raise BackendError("Failed to get stats", original_error=e) from e