"""In-memory backend for rate limiting - suitable for single-process applications.""" from __future__ import annotations import asyncio import contextlib import time from collections import OrderedDict from typing import Any from fastapi_traffic.backends.base import Backend class MemoryBackend(Backend): """Thread-safe in-memory backend with LRU eviction and TTL support.""" __slots__ = ("_cleanup_interval", "_cleanup_task", "_data", "_lock", "_max_size") def __init__( self, *, max_size: int = 10000, cleanup_interval: float = 60.0, ) -> None: """Initialize the memory backend. Args: max_size: Maximum number of entries to store (LRU eviction). cleanup_interval: Interval in seconds for cleaning expired entries. """ self._data: OrderedDict[str, tuple[dict[str, Any], float]] = OrderedDict() self._lock = asyncio.Lock() self._max_size = max_size self._cleanup_interval = cleanup_interval self._cleanup_task: asyncio.Task[None] | None = None async def start_cleanup(self) -> None: """Start the background cleanup task.""" if self._cleanup_task is None: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) async def _cleanup_loop(self) -> None: """Background task to clean up expired entries.""" while True: try: await asyncio.sleep(self._cleanup_interval) await self._cleanup_expired() except asyncio.CancelledError: break except Exception: pass async def _cleanup_expired(self) -> None: """Remove expired entries.""" now = time.time() async with self._lock: expired_keys = [ key for key, (_, expires_at) in self._data.items() if expires_at <= now ] for key in expired_keys: del self._data[key] def _evict_if_needed(self) -> None: """Evict oldest entries if over max size (must be called with lock held).""" while len(self._data) > self._max_size: self._data.popitem(last=False) async def get(self, key: str) -> dict[str, Any] | None: """Get the current state for a key.""" async with self._lock: if key not in self._data: return None value, expires_at = self._data[key] if expires_at <= time.time(): del self._data[key] return None self._data.move_to_end(key) return value.copy() async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None: """Set the state for a key with TTL.""" expires_at = time.time() + ttl async with self._lock: self._data[key] = (value.copy(), expires_at) self._data.move_to_end(key) self._evict_if_needed() async def delete(self, key: str) -> None: """Delete the state for a key.""" async with self._lock: self._data.pop(key, None) async def exists(self, key: str) -> bool: """Check if a key exists and is not expired.""" async with self._lock: if key not in self._data: return False _, expires_at = self._data[key] if expires_at <= time.time(): del self._data[key] return False return True async def increment(self, key: str, amount: int = 1) -> int: """Atomically increment a counter.""" async with self._lock: if key in self._data: value, expires_at = self._data[key] if expires_at > time.time(): current = int(value.get("count", 0)) new_value = current + amount value["count"] = new_value self._data[key] = (value, expires_at) return new_value return amount async def clear(self) -> None: """Clear all rate limit data.""" async with self._lock: self._data.clear() async def close(self) -> None: """Stop cleanup task and clear data.""" if self._cleanup_task is not None: self._cleanup_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._cleanup_task self._cleanup_task = None await self.clear() async def ping(self) -> bool: """Check if the backend is available. Always returns True for memory backend.""" return True async def get_stats(self) -> dict[str, Any]: """Get statistics about the rate limit storage.""" async with self._lock: return { "total_keys": len(self._data), "max_size": self._max_size, "backend": "memory", } def __len__(self) -> int: """Return the number of stored entries.""" return len(self._data)