refactor: use contextlib.suppress and sort __slots__ in backends
This commit is contained in:
@@ -71,6 +71,7 @@ class Backend(ABC):
|
|||||||
"""Clear all rate limit data."""
|
"""Clear all rate limit data."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the backend connection."""
|
"""Close the backend connection."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -13,7 +14,7 @@ from fastapi_traffic.backends.base import Backend
|
|||||||
class MemoryBackend(Backend):
|
class MemoryBackend(Backend):
|
||||||
"""Thread-safe in-memory backend with LRU eviction and TTL support."""
|
"""Thread-safe in-memory backend with LRU eviction and TTL support."""
|
||||||
|
|
||||||
__slots__ = ("_data", "_lock", "_max_size", "_cleanup_interval", "_cleanup_task")
|
__slots__ = ("_cleanup_interval", "_cleanup_task", "_data", "_lock", "_max_size")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -127,10 +128,8 @@ class MemoryBackend(Backend):
|
|||||||
"""Stop cleanup task and clear data."""
|
"""Stop cleanup task and clear data."""
|
||||||
if self._cleanup_task is not None:
|
if self._cleanup_task is not None:
|
||||||
self._cleanup_task.cancel()
|
self._cleanup_task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._cleanup_task
|
await self._cleanup_task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
await self.clear()
|
await self.clear()
|
||||||
|
|
||||||
|
|||||||
@@ -3,27 +3,30 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from typing import TYPE_CHECKING, Any
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi_traffic.backends.base import Backend
|
from fastapi_traffic.backends.base import Backend
|
||||||
from fastapi_traffic.exceptions import BackendError
|
from fastapi_traffic.exceptions import BackendError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class SQLiteBackend(Backend):
|
class SQLiteBackend(Backend):
|
||||||
"""SQLite-based backend with connection pooling and async support."""
|
"""SQLite-based backend with connection pooling and async support."""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"_db_path",
|
|
||||||
"_connection",
|
|
||||||
"_lock",
|
|
||||||
"_cleanup_interval",
|
"_cleanup_interval",
|
||||||
"_cleanup_task",
|
"_cleanup_task",
|
||||||
"_pool_size",
|
"_connection",
|
||||||
"_connections",
|
"_connections",
|
||||||
|
"_db_path",
|
||||||
|
"_lock",
|
||||||
|
"_pool_size",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -59,9 +62,7 @@ class SQLiteBackend(Backend):
|
|||||||
"""Ensure a database connection exists."""
|
"""Ensure a database connection exists."""
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
self._connection = await loop.run_in_executor(
|
self._connection = await loop.run_in_executor(None, self._create_connection)
|
||||||
None, self._create_connection
|
|
||||||
)
|
|
||||||
assert self._connection is not None
|
assert self._connection is not None
|
||||||
return self._connection
|
return self._connection
|
||||||
|
|
||||||
@@ -87,16 +88,20 @@ class SQLiteBackend(Backend):
|
|||||||
|
|
||||||
def _create_tables_sync(self, conn: sqlite3.Connection) -> None:
|
def _create_tables_sync(self, conn: sqlite3.Connection) -> None:
|
||||||
"""Synchronously create tables."""
|
"""Synchronously create tables."""
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS rate_limits (
|
CREATE TABLE IF NOT EXISTS rate_limits (
|
||||||
key TEXT PRIMARY KEY,
|
key TEXT PRIMARY KEY,
|
||||||
data TEXT NOT NULL,
|
data TEXT NOT NULL,
|
||||||
expires_at REAL NOT NULL
|
expires_at REAL NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
conn.execute("""
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at)
|
CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
async def _cleanup_loop(self) -> None:
|
async def _cleanup_loop(self) -> None:
|
||||||
"""Background task to clean up expired entries."""
|
"""Background task to clean up expired entries."""
|
||||||
@@ -247,10 +252,8 @@ class SQLiteBackend(Backend):
|
|||||||
"""Close the database connection."""
|
"""Close the database connection."""
|
||||||
if self._cleanup_task is not None:
|
if self._cleanup_task is not None:
|
||||||
self._cleanup_task.cancel()
|
self._cleanup_task.cancel()
|
||||||
try:
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._cleanup_task
|
await self._cleanup_task
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
|
|
||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user