# Package overview: core rate limiting with multiple algorithms (sliding
# window, token bucket, etc.), SQLite and memory backends, decorator and
# dependency-injection patterns, middleware support, and example usage files.
"""Redis backend for rate limiting - distributed storage for multi-node deployments."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from fastapi_traffic.backends.base import Backend
|
|
from fastapi_traffic.exceptions import BackendError
|
|
|
|
if TYPE_CHECKING:
|
|
from redis.asyncio import Redis
|
|
|
|
|
|
class RedisBackend(Backend):
    """Redis-based backend for distributed rate limiting.

    State is stored as JSON strings under a configurable key prefix so
    multiple application nodes can share a single limit store.
    """

    __slots__ = ("_client", "_key_prefix", "_owns_client")

    def __init__(
        self,
        client: Redis[bytes],
        *,
        key_prefix: str = "fastapi_traffic",
    ) -> None:
        """Initialize the Redis backend.

        Args:
            client: An async Redis client instance.
            key_prefix: Prefix for all rate limit keys.
        """
        self._client = client
        self._key_prefix = key_prefix
        # close() only disconnects clients this backend created itself
        # (i.e. via from_url()); injected clients belong to the caller.
        self._owns_client = False

    @classmethod
    async def from_url(
        cls,
        url: str = "redis://localhost:6379/0",
        *,
        key_prefix: str = "fastapi_traffic",
        **kwargs: Any,
    ) -> RedisBackend:
        """Create a RedisBackend from a Redis URL.

        Args:
            url: Redis connection URL.
            key_prefix: Prefix for all rate limit keys.
            **kwargs: Additional arguments passed to Redis.from_url().

        Returns:
            A new RedisBackend instance that owns its client, so close()
            will disconnect it.

        Raises:
            ImportError: If the optional ``redis`` package is not installed.
        """
        try:
            # Imported lazily so the package works without redis installed.
            from redis.asyncio import Redis
        except ImportError as e:
            msg = "redis package is required for RedisBackend. Install with: pip install redis"
            raise ImportError(msg) from e

        client: Redis[bytes] = Redis.from_url(url, **kwargs)
        instance = cls(client, key_prefix=key_prefix)
        instance._owns_client = True
        return instance

    def _make_key(self, key: str) -> str:
        """Create a prefixed key."""
        return f"{self._key_prefix}:{key}"

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key.

        Returns:
            The stored state decoded from JSON, or None if the key is absent.

        Raises:
            BackendError: If the Redis operation or JSON decoding fails.
        """
        try:
            full_key = self._make_key(key)
            data = await self._client.get(full_key)
            if data is None:
                return None
            result: dict[str, Any] = json.loads(data)
            return result
        except Exception as e:
            # `from e` chains the cause so the original traceback survives.
            raise BackendError(f"Failed to get key {key}", original_error=e) from e

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        Args:
            key: Un-prefixed rate limit key.
            value: JSON-serializable state to store.
            ttl: Time-to-live in seconds. SETEX only accepts whole seconds,
                so the value is truncated and padded by one second to avoid
                expiring earlier than requested.

        Raises:
            BackendError: If serialization or the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            data = json.dumps(value)
            await self._client.setex(full_key, int(ttl) + 1, data)
        except Exception as e:
            raise BackendError(f"Failed to set key {key}", original_error=e) from e

    async def delete(self, key: str) -> None:
        """Delete the state for a key.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            await self._client.delete(full_key)
        except Exception as e:
            raise BackendError(f"Failed to delete key {key}", original_error=e) from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists.

        Raises:
            BackendError: If the Redis operation fails.
        """
        try:
            full_key = self._make_key(key)
            result = await self._client.exists(full_key)
            return bool(result)
        except Exception as e:
            raise BackendError(f"Failed to check key {key}", original_error=e) from e

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment a counter using Redis INCRBY.

        Args:
            key: Un-prefixed counter key.
            amount: Increment step (defaults to 1).

        Returns:
            The counter value after incrementing.

        Raises:
            BackendError: If the Redis operation fails.
        """
        # NOTE(review): INCRBY creates the key with no TTL when it does not
        # exist, so counters incremented here never expire on their own —
        # confirm callers set an expiry or rely on clear().
        try:
            full_key = self._make_key(key)
            result = await self._client.incrby(full_key, amount)
            return int(result)
        except Exception as e:
            raise BackendError(f"Failed to increment key {key}", original_error=e) from e

    async def clear(self) -> None:
        """Clear all rate limit data with this prefix.

        Uses SCAN (cursor-based, non-blocking) rather than KEYS so large
        keyspaces are walked incrementally; matching keys are deleted in
        batches of up to 100 per SCAN page.

        Raises:
            BackendError: If a Redis operation fails.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                if keys:
                    await self._client.delete(*keys)
                if cursor == 0:
                    # Cursor 0 signals the SCAN iteration is complete.
                    break
        except Exception as e:
            raise BackendError("Failed to clear rate limits", original_error=e) from e

    async def close(self) -> None:
        """Close the Redis connection if we own it (created via from_url())."""
        if self._owns_client:
            await self._client.aclose()

    async def ping(self) -> bool:
        """Check if Redis is reachable.

        Returns:
            True if PING succeeds, False on any error (best-effort health
            check; deliberately never raises).
        """
        try:
            await self._client.ping()
            return True
        except Exception:
            return False

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with ``total_keys`` (count of keys under this prefix,
            via SCAN), ``used_memory`` (human-readable, from INFO memory,
            or "unknown"), and ``key_prefix``.

        Raises:
            BackendError: If a Redis operation fails.
        """
        try:
            pattern = f"{self._key_prefix}:*"
            cursor: int = 0
            count = 0
            while True:
                cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
                count += len(keys)
                if cursor == 0:
                    break

            info = await self._client.info("memory")
            return {
                "total_keys": count,
                "used_memory": info.get("used_memory_human", "unknown"),
                "key_prefix": self._key_prefix,
            }
        except Exception as e:
            raise BackendError("Failed to get stats", original_error=e) from e


# Lua scripts for atomic operations.
#
# These run server-side (EVAL/EVALSHA) so each rate-limit decision's
# read-modify-write cycle is atomic — no races between application nodes.

# Sliding-window log limiter backed by a sorted set (score = timestamp).
#   KEYS[1]: the (already prefixed) rate-limit key.
#   ARGV[1]: current time; ARGV[2]: window size; ARGV[3]: request limit.
#            ARGV[1] and ARGV[2] must use the same time unit — supplied by
#            the caller, not visible here.
# Returns {1, remaining} when allowed, or {0, 0, retry_after} when denied —
# note the two branches return tables of DIFFERENT lengths; callers must
# branch on the first element before indexing further.
# NOTE(review): Redis converts Lua numbers in script replies to integers,
# truncating fractions — retry_after loses sub-unit precision. Confirm the
# caller's time units tolerate this.
SLIDING_WINDOW_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window_size = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local window_start = now - window_size

-- Remove expired entries
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)

-- Count current entries
local count = redis.call('ZCARD', key)

if count < limit then
    -- Add new entry
    redis.call('ZADD', key, now, now .. ':' .. math.random())
    redis.call('EXPIRE', key, math.ceil(window_size) + 1)
    return {1, limit - count - 1}
else
    -- Get oldest entry for retry-after calculation
    local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
    local retry_after = 0
    if #oldest > 0 then
        retry_after = oldest[2] + window_size - now
    end
    return {0, 0, retry_after}
end
"""

# Token-bucket limiter. State is a JSON blob {tokens, last_update} stored
# under KEYS[1] and rewritten with SETEX on every call.
#   KEYS[1]: the (already prefixed) rate-limit key.
#   ARGV[1]: current time; ARGV[2]: bucket capacity; ARGV[3]: refill rate
#            (tokens per unit of ARGV[1]'s time scale); ARGV[4]: TTL in
#            whole seconds for the state key.
# Returns {allowed (0/1), whole tokens remaining (floored), retry_after}.
# NOTE(review): as with SLIDING_WINDOW_SCRIPT, Redis truncates Lua numbers
# in replies to integers, so retry_after loses fractional precision —
# confirm callers' time units.
TOKEN_BUCKET_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local bucket_size = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
local tokens, last_update

if data then
    local decoded = cjson.decode(data)
    tokens = decoded.tokens
    last_update = decoded.last_update
else
    tokens = bucket_size
    last_update = now
end

-- Refill tokens
local elapsed = now - last_update
tokens = math.min(bucket_size, tokens + elapsed * refill_rate)

local allowed = 0
local retry_after = 0

if tokens >= 1 then
    tokens = tokens - 1
    allowed = 1
else
    retry_after = (1 - tokens) / refill_rate
end

-- Save state
redis.call('SETEX', key, ttl, cjson.encode({tokens = tokens, last_update = now}))

return {allowed, math.floor(tokens), retry_after}
"""