"""Redis backend for rate limiting - distributed storage for multi-node deployments.""" from __future__ import annotations import json from typing import TYPE_CHECKING, Any from fastapi_traffic.backends.base import Backend from fastapi_traffic.exceptions import BackendError if TYPE_CHECKING: from redis.asyncio import Redis class RedisBackend(Backend): """Redis-based backend for distributed rate limiting.""" __slots__ = ("_client", "_key_prefix", "_owns_client") def __init__( self, client: Redis, *, key_prefix: str = "fastapi_traffic", ) -> None: """Initialize the Redis backend. Args: client: An async Redis client instance. key_prefix: Prefix for all rate limit keys. """ self._client = client self._key_prefix = key_prefix self._owns_client = False @classmethod async def from_url( cls, url: str = "redis://localhost:6379/0", *, key_prefix: str = "fastapi_traffic", **kwargs: Any, ) -> RedisBackend: """Create a RedisBackend from a Redis URL. Args: url: Redis connection URL. key_prefix: Prefix for all rate limit keys. **kwargs: Additional arguments passed to Redis.from_url(). Returns: A new RedisBackend instance. """ try: from redis.asyncio import Redis except ImportError as e: msg = "redis package is required for RedisBackend. Install with: pip install redis" raise ImportError(msg) from e client: Redis = Redis.from_url(url, **kwargs) # pyright: ignore[reportUnknownMemberType] # fmt: skip # note: No type stubs for redis-py, so we ignore the type errors instance = cls(client, key_prefix=key_prefix) instance._owns_client = True return instance def _make_key(self, key: str) -> str: """Create a prefixed key.""" return f"{self._key_prefix}:{key}" async def get(self, key: str) -> dict[str, Any] | None: """Get the current state for a key.""" try: full_key = self._make_key(key) data = await self._client.get(full_key) if data is None: return None result: dict[str, Any] = json.loads(data) return result except Exception as e: raise BackendError(f"Failed to get key {key}", original_error=e) async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None: """Set the state for a key with TTL.""" try: full_key = self._make_key(key) data = json.dumps(value) await self._client.setex(full_key, int(ttl) + 1, data) except Exception as e: raise BackendError(f"Failed to set key {key}", original_error=e) async def delete(self, key: str) -> None: """Delete the state for a key.""" try: full_key = self._make_key(key) await self._client.delete(full_key) except Exception as e: raise BackendError(f"Failed to delete key {key}", original_error=e) async def exists(self, key: str) -> bool: """Check if a key exists.""" try: full_key = self._make_key(key) result = await self._client.exists(full_key) return bool(result) except Exception as e: raise BackendError(f"Failed to check key {key}", original_error=e) async def increment(self, key: str, amount: int = 1) -> int: """Atomically increment a counter using Redis INCRBY.""" try: full_key = self._make_key(key) result = await self._client.incrby(full_key, amount) return int(result) except Exception as e: raise BackendError(f"Failed to increment key {key}", original_error=e) async def clear(self) -> None: """Clear all rate limit data with this prefix.""" try: pattern = f"{self._key_prefix}:*" cursor: int = 0 while True: cursor, keys = ( await self._client.scan( # pyright: ignore[reportUnknownMemberType] cursor, match=pattern, count=100 ) ) if keys: await self._client.delete(*keys) if cursor == 0: break except Exception as e: raise BackendError("Failed to clear rate limits", original_error=e) async def close(self) -> None: """Close the Redis connection if we own it.""" if self._owns_client: await self._client.aclose() async def ping(self) -> bool: """Check if Redis is reachable.""" return await self._client.ping() # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues, reportUnknownVariableType, reportReturnType] # fmt: skip async def get_stats(self) -> dict[str, Any]: """Get statistics about the rate limit storage.""" try: pattern = f"{self._key_prefix}:*" cursor: int = 0 count = 0 while True: cursor, keys = ( await self._client.scan( # pyright: ignore[reportUnknownMemberType] cursor, match=pattern, count=100 ) ) count += len(keys) if cursor == 0: break info: dict[str, Any] = ( await self._client.info( # pyright: ignore[reportUnknownMemberType] "memory" ) ) return { "total_keys": count, "used_memory": info.get("used_memory_human", "unknown"), "key_prefix": self._key_prefix, } except Exception as e: raise BackendError("Failed to get stats", original_error=e) # Lua scripts for atomic operations SLIDING_WINDOW_SCRIPT = """ local key = KEYS[1] local now = tonumber(ARGV[1]) local window_size = tonumber(ARGV[2]) local limit = tonumber(ARGV[3]) local window_start = now - window_size -- Remove expired entries redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start) -- Count current entries local count = redis.call('ZCARD', key) if count < limit then -- Add new entry redis.call('ZADD', key, now, now .. ':' .. math.random()) redis.call('EXPIRE', key, math.ceil(window_size) + 1) return {1, limit - count - 1} else -- Get oldest entry for retry-after calculation local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local retry_after = 0 if #oldest > 0 then retry_after = oldest[2] + window_size - now end return {0, 0, retry_after} end """ TOKEN_BUCKET_SCRIPT = """ local key = KEYS[1] local now = tonumber(ARGV[1]) local bucket_size = tonumber(ARGV[2]) local refill_rate = tonumber(ARGV[3]) local ttl = tonumber(ARGV[4]) local data = redis.call('GET', key) local tokens, last_update if data then local decoded = cjson.decode(data) tokens = decoded.tokens last_update = decoded.last_update else tokens = bucket_size last_update = now end -- Refill tokens local elapsed = now - last_update tokens = math.min(bucket_size, tokens + elapsed * refill_rate) local allowed = 0 local retry_after = 0 if tokens >= 1 then tokens = tokens - 1 allowed = 1 else retry_after = (1 - tokens) / refill_rate end -- Save state redis.call('SETEX', key, ttl, cjson.encode({tokens = tokens, last_update = now})) return {allowed, math.floor(tokens), retry_after} """