Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
139
fastapi_traffic/backends/memory.py
Normal file
139
fastapi_traffic/backends/memory.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""In-memory backend for rate limiting - suitable for single-process applications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
|
||||
class MemoryBackend(Backend):
|
||||
"""Thread-safe in-memory backend with LRU eviction and TTL support."""
|
||||
|
||||
__slots__ = ("_data", "_lock", "_max_size", "_cleanup_interval", "_cleanup_task")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_size: int = 10000,
|
||||
cleanup_interval: float = 60.0,
|
||||
) -> None:
|
||||
"""Initialize the memory backend.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries to store (LRU eviction).
|
||||
cleanup_interval: Interval in seconds for cleaning expired entries.
|
||||
"""
|
||||
self._data: OrderedDict[str, tuple[dict[str, Any], float]] = OrderedDict()
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_size = max_size
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start_cleanup(self) -> None:
|
||||
"""Start the background cleanup task."""
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""Background task to clean up expired entries."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_expired()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _cleanup_expired(self) -> None:
|
||||
"""Remove expired entries."""
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
expired_keys = [
|
||||
key for key, (_, expires_at) in self._data.items() if expires_at <= now
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._data[key]
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict oldest entries if over max size (must be called with lock held)."""
|
||||
while len(self._data) > self._max_size:
|
||||
self._data.popitem(last=False)
|
||||
|
||||
async def get(self, key: str) -> dict[str, Any] | None:
|
||||
"""Get the current state for a key."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return None
|
||||
|
||||
value, expires_at = self._data[key]
|
||||
if expires_at <= time.time():
|
||||
del self._data[key]
|
||||
return None
|
||||
|
||||
self._data.move_to_end(key)
|
||||
return value.copy()
|
||||
|
||||
async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
|
||||
"""Set the state for a key with TTL."""
|
||||
expires_at = time.time() + ttl
|
||||
async with self._lock:
|
||||
self._data[key] = (value.copy(), expires_at)
|
||||
self._data.move_to_end(key)
|
||||
self._evict_if_needed()
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete the state for a key."""
|
||||
async with self._lock:
|
||||
self._data.pop(key, None)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return False
|
||||
|
||||
_, expires_at = self._data[key]
|
||||
if expires_at <= time.time():
|
||||
del self._data[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def increment(self, key: str, amount: int = 1) -> int:
|
||||
"""Atomically increment a counter."""
|
||||
async with self._lock:
|
||||
if key in self._data:
|
||||
value, expires_at = self._data[key]
|
||||
if expires_at > time.time():
|
||||
current = int(value.get("count", 0))
|
||||
new_value = current + amount
|
||||
value["count"] = new_value
|
||||
self._data[key] = (value, expires_at)
|
||||
return new_value
|
||||
|
||||
return amount
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear all rate limit data."""
|
||||
async with self._lock:
|
||||
self._data.clear()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Stop cleanup task and clear data."""
|
||||
if self._cleanup_task is not None:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
await self.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of stored entries."""
|
||||
return len(self._data)
|
||||
Reference in New Issue
Block a user