"""SQLite backend for rate limiting - persistent storage for single-node deployments.""" from __future__ import annotations import asyncio import json import sqlite3 import time from pathlib import Path from typing import Any from fastapi_traffic.backends.base import Backend from fastapi_traffic.exceptions import BackendError class SQLiteBackend(Backend): """SQLite-based backend with connection pooling and async support.""" __slots__ = ( "_db_path", "_connection", "_lock", "_cleanup_interval", "_cleanup_task", "_pool_size", "_connections", ) def __init__( self, db_path: str | Path = ":memory:", *, cleanup_interval: float = 300.0, pool_size: int = 5, ) -> None: """Initialize the SQLite backend. Args: db_path: Path to SQLite database file or ":memory:" for in-memory. cleanup_interval: Interval in seconds for cleaning expired entries. pool_size: Number of connections in the pool. """ self._db_path = str(db_path) self._connection: sqlite3.Connection | None = None self._lock = asyncio.Lock() self._cleanup_interval = cleanup_interval self._cleanup_task: asyncio.Task[None] | None = None self._pool_size = pool_size self._connections: list[sqlite3.Connection] = [] async def initialize(self) -> None: """Initialize the database and create tables.""" await self._ensure_connection() await self._create_tables() if self._cleanup_task is None: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) async def _ensure_connection(self) -> sqlite3.Connection: """Ensure a database connection exists.""" if self._connection is None: loop = asyncio.get_event_loop() self._connection = await loop.run_in_executor( None, self._create_connection ) assert self._connection is not None return self._connection def _create_connection(self) -> sqlite3.Connection: """Create a new SQLite connection with optimized settings.""" conn = sqlite3.connect( self._db_path, check_same_thread=False, isolation_level=None, ) conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") conn.execute("PRAGMA cache_size=10000") conn.execute("PRAGMA temp_store=MEMORY") conn.row_factory = sqlite3.Row return conn async def _create_tables(self) -> None: """Create the rate limit tables.""" conn = await self._ensure_connection() loop = asyncio.get_event_loop() await loop.run_in_executor(None, self._create_tables_sync, conn) def _create_tables_sync(self, conn: sqlite3.Connection) -> None: """Synchronously create tables.""" conn.execute(""" CREATE TABLE IF NOT EXISTS rate_limits ( key TEXT PRIMARY KEY, data TEXT NOT NULL, expires_at REAL NOT NULL ) """) conn.execute(""" CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at) """) async def _cleanup_loop(self) -> None: """Background task to clean up expired entries.""" while True: try: await asyncio.sleep(self._cleanup_interval) await self._cleanup_expired() except asyncio.CancelledError: break except Exception: pass async def _cleanup_expired(self) -> None: """Remove expired entries.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() await loop.run_in_executor( None, lambda: conn.execute( "DELETE FROM rate_limits WHERE expires_at <= ?", (time.time(),) ), ) except Exception as e: raise BackendError("Failed to cleanup expired entries", original_error=e) async def get(self, key: str) -> dict[str, Any] | None: """Get the current state for a key.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() def _get() -> dict[str, Any] | None: cursor = conn.execute( "SELECT data, expires_at FROM rate_limits WHERE key = ?", (key,), ) row = cursor.fetchone() if row is None: return None expires_at = row["expires_at"] if expires_at <= time.time(): conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,)) return None data: dict[str, Any] = json.loads(row["data"]) return data return await loop.run_in_executor(None, _get) except Exception as e: raise BackendError(f"Failed to get key {key}", original_error=e) async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None: """Set the state for a key with TTL.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() expires_at = time.time() + ttl data_json = json.dumps(value) def _set() -> None: conn.execute( """ INSERT OR REPLACE INTO rate_limits (key, data, expires_at) VALUES (?, ?, ?) """, (key, data_json, expires_at), ) await loop.run_in_executor(None, _set) except Exception as e: raise BackendError(f"Failed to set key {key}", original_error=e) async def delete(self, key: str) -> None: """Delete the state for a key.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() await loop.run_in_executor( None, lambda: conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,)), ) except Exception as e: raise BackendError(f"Failed to delete key {key}", original_error=e) async def exists(self, key: str) -> bool: """Check if a key exists and is not expired.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() def _exists() -> bool: cursor = conn.execute( "SELECT 1 FROM rate_limits WHERE key = ? AND expires_at > ?", (key, time.time()), ) return cursor.fetchone() is not None return await loop.run_in_executor(None, _exists) except Exception as e: raise BackendError(f"Failed to check key {key}", original_error=e) async def increment(self, key: str, amount: int = 1) -> int: """Atomically increment a counter.""" async with self._lock: try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() def _increment() -> int: cursor = conn.execute( "SELECT data, expires_at FROM rate_limits WHERE key = ?", (key,), ) row = cursor.fetchone() if row is None or row["expires_at"] <= time.time(): return amount data: dict[str, Any] = json.loads(row["data"]) current = int(data.get("count", 0)) new_value = current + amount data["count"] = new_value conn.execute( "UPDATE rate_limits SET data = ? WHERE key = ?", (json.dumps(data), key), ) return new_value return await loop.run_in_executor(None, _increment) except Exception as e: raise BackendError(f"Failed to increment key {key}", original_error=e) async def clear(self) -> None: """Clear all rate limit data.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() await loop.run_in_executor( None, lambda: conn.execute("DELETE FROM rate_limits") ) except Exception as e: raise BackendError("Failed to clear rate limits", original_error=e) async def close(self) -> None: """Close the database connection.""" if self._cleanup_task is not None: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass self._cleanup_task = None if self._connection is not None: self._connection.close() self._connection = None for conn in self._connections: conn.close() self._connections.clear() async def vacuum(self) -> None: """Optimize the database by running VACUUM.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: conn.execute("VACUUM")) except Exception as e: raise BackendError("Failed to vacuum database", original_error=e) async def get_stats(self) -> dict[str, Any]: """Get statistics about the rate limit storage.""" try: conn = await self._ensure_connection() loop = asyncio.get_event_loop() def _stats() -> dict[str, Any]: cursor = conn.execute("SELECT COUNT(*) as total FROM rate_limits") total = cursor.fetchone()["total"] cursor = conn.execute( "SELECT COUNT(*) as active FROM rate_limits WHERE expires_at > ?", (time.time(),), ) active = cursor.fetchone()["active"] return { "total_entries": total, "active_entries": active, "expired_entries": total - active, "db_path": self._db_path, } return await loop.run_in_executor(None, _stats) except Exception as e: raise BackendError("Failed to get stats", original_error=e)