Covers all algorithms, backends, decorators, middleware, and integration scenarios. Added conftest.py with shared fixtures and pytest-asyncio configuration.
434 lines
15 KiB
Python
434 lines
15 KiB
Python
"""Tests for rate limit backends.
|
|
|
|
Comprehensive tests covering:
|
|
- Basic CRUD operations
|
|
- TTL expiration behavior
|
|
- Concurrent access and race conditions
|
|
- LRU eviction (memory backend)
|
|
- Connection management
|
|
- Statistics and monitoring
|
|
- Error handling and edge cases
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import AsyncGenerator
|
|
|
|
import pytest
|
|
|
|
from fastapi_traffic.backends.memory import MemoryBackend
|
|
from fastapi_traffic.backends.sqlite import SQLiteBackend
|
|
|
|
|
|
@pytest.mark.asyncio
class TestMemoryBackend:
    """Core behaviour checks for MemoryBackend."""

    @pytest.fixture
    async def backend(self) -> AsyncGenerator[MemoryBackend, None]:
        """Yield a MemoryBackend and close it once the test has finished."""
        instance = MemoryBackend(max_size=100, cleanup_interval=1.0)
        yield instance
        await instance.close()

    async def test_set_and_get(self, backend: MemoryBackend) -> None:
        """A value written with set() is returned by get()."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        stored = await backend.get("test_key")
        assert stored is not None
        assert stored["count"] == 5

    async def test_get_nonexistent(self, backend: MemoryBackend) -> None:
        """get() on a key that was never written yields None."""
        missing = await backend.get("nonexistent")
        assert missing is None

    async def test_delete(self, backend: MemoryBackend) -> None:
        """A key is no longer readable after delete()."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        await backend.delete("test_key")
        after_delete = await backend.get("test_key")
        assert after_delete is None

    async def test_exists(self, backend: MemoryBackend) -> None:
        """exists() is falsy before the key is set and truthy afterwards."""
        present_before = await backend.exists("test_key")
        assert not present_before
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        present_after = await backend.exists("test_key")
        assert present_after

    async def test_increment(self, backend: MemoryBackend) -> None:
        """increment() by 3 on a stored count of 5 returns 8."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        new_total = await backend.increment("test_key", 3)
        assert new_total == 8

    async def test_clear(self, backend: MemoryBackend) -> None:
        """clear() removes every key that was stored."""
        for key, count in (("key1", 1), ("key2", 2)):
            await backend.set(key, {"count": count}, ttl=60.0)
        await backend.clear()
        for key in ("key1", "key2"):
            assert not await backend.exists(key)

    async def test_ttl_expiration(self, backend: MemoryBackend) -> None:
        """An entry with a 0.1s TTL is gone after waiting 0.2s."""
        await backend.set("test_key", {"count": 5}, ttl=0.1)
        await asyncio.sleep(0.2)
        expired = await backend.get("test_key")
        assert expired is None

    async def test_lru_eviction(self) -> None:
        """Writing a fourth key into a max_size=3 backend drops key1."""
        cache = MemoryBackend(max_size=3)
        try:
            for key, value in (("key1", 1), ("key2", 2), ("key3", 3), ("key4", 4)):
                await cache.set(key, {"v": value}, ttl=60.0)

            assert not await cache.exists("key1")
            for survivor in ("key2", "key3", "key4"):
                assert await cache.exists(survivor)
        finally:
            await cache.close()
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSQLiteBackend:
    """Core behaviour checks for SQLiteBackend."""

    @pytest.fixture
    async def backend(self) -> AsyncGenerator[SQLiteBackend, None]:
        """Yield an initialized ":memory:" SQLiteBackend, closing it afterwards."""
        instance = SQLiteBackend(":memory:", cleanup_interval=1.0)
        await instance.initialize()
        yield instance
        await instance.close()

    async def test_set_and_get(self, backend: SQLiteBackend) -> None:
        """A value written with set() is returned by get()."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        stored = await backend.get("test_key")
        assert stored is not None
        assert stored["count"] == 5

    async def test_get_nonexistent(self, backend: SQLiteBackend) -> None:
        """get() on a key that was never written yields None."""
        missing = await backend.get("nonexistent")
        assert missing is None

    async def test_delete(self, backend: SQLiteBackend) -> None:
        """A key is no longer readable after delete()."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        await backend.delete("test_key")
        after_delete = await backend.get("test_key")
        assert after_delete is None

    async def test_exists(self, backend: SQLiteBackend) -> None:
        """exists() is falsy before the key is set and truthy afterwards."""
        present_before = await backend.exists("test_key")
        assert not present_before
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        present_after = await backend.exists("test_key")
        assert present_after

    async def test_increment(self, backend: SQLiteBackend) -> None:
        """increment() by 3 on a stored count of 5 returns 8."""
        await backend.set("test_key", {"count": 5}, ttl=60.0)
        new_total = await backend.increment("test_key", 3)
        assert new_total == 8

    async def test_clear(self, backend: SQLiteBackend) -> None:
        """clear() removes every key that was stored."""
        for key, count in (("key1", 1), ("key2", 2)):
            await backend.set(key, {"count": count}, ttl=60.0)
        await backend.clear()
        for key in ("key1", "key2"):
            assert not await backend.exists(key)

    async def test_get_stats(self, backend: SQLiteBackend) -> None:
        """get_stats() counts two total and two active entries after two sets."""
        for key, count in (("key1", 1), ("key2", 2)):
            await backend.set(key, {"count": count}, ttl=60.0)
        report = await backend.get_stats()
        assert report["total_entries"] == 2
        assert report["active_entries"] == 2
|
|
|
|
|
|
@pytest.mark.asyncio
class TestMemoryBackendAdvanced:
    """Concurrency, eviction, cleanup and helper-method checks for MemoryBackend."""

    async def test_concurrent_writes(self) -> None:
        """100 concurrent set() calls each leave their own value readable."""
        store = MemoryBackend(max_size=1000)
        try:
            async def store_value(index: int) -> None:
                await store.set(f"key_{index}", {"value": index}, ttl=60.0)

            await asyncio.gather(*(store_value(index) for index in range(100)))

            for index in range(100):
                fetched = await store.get(f"key_{index}")
                assert fetched is not None
                assert fetched["value"] == index
        finally:
            await store.close()

    async def test_concurrent_increments(self) -> None:
        """50 concurrent increments yield 50 distinct results peaking at 50."""
        store = MemoryBackend()
        try:
            await store.set("counter", {"count": 0}, ttl=60.0)

            async def bump() -> int:
                return await store.increment("counter", 1)

            outcomes = await asyncio.gather(*(bump() for _ in range(50)))

            # Every call must observe a different value, ending at 50.
            distinct = set(outcomes)
            assert len(distinct) == 50
            assert max(outcomes) == 50
        finally:
            await store.close()

    async def test_lru_eviction_order(self) -> None:
        """Reading key1 before an overflow write makes key2 the evicted entry."""
        store = MemoryBackend(max_size=3)
        try:
            for key, value in (("key1", 1), ("key2", 2), ("key3", 3)):
                await store.set(key, {"v": value}, ttl=60.0)

            # Touch key1 so it is no longer the least recently used entry.
            await store.get("key1")

            await store.set("key4", {"v": 4}, ttl=60.0)

            assert await store.exists("key1")
            assert not await store.exists("key2")
            assert await store.exists("key3")
            assert await store.exists("key4")
        finally:
            await store.close()

    async def test_cleanup_task_removes_expired(self) -> None:
        """After start_cleanup(), a 0.05s-TTL key is gone 0.2s later; a 60s one stays."""
        store = MemoryBackend(max_size=100, cleanup_interval=0.1)
        try:
            await store.start_cleanup()
            await store.set("expire_soon", {"v": 1}, ttl=0.05)
            await store.set("keep", {"v": 2}, ttl=60.0)

            present_at_start = await store.exists("expire_soon")
            assert present_at_start

            await asyncio.sleep(0.2)

            assert not await store.exists("expire_soon")
            assert await store.exists("keep")
        finally:
            await store.close()

    async def test_get_stats(self) -> None:
        """get_stats() reports key count, max_size and the backend name."""
        store = MemoryBackend(max_size=100)
        try:
            for key, value in (("key1", 1), ("key2", 2), ("key3", 3)):
                await store.set(key, {"v": value}, ttl=60.0)

            report = await store.get_stats()
            assert report["total_keys"] == 3
            assert report["max_size"] == 100
            assert report["backend"] == "memory"
        finally:
            await store.close()

    async def test_ping_always_returns_true(self) -> None:
        """ping() on a fresh MemoryBackend returns exactly True."""
        store = MemoryBackend()
        try:
            alive = await store.ping()
            assert alive is True
        finally:
            await store.close()

    async def test_context_manager(self) -> None:
        """A MemoryBackend entered via `async with` supports set() and get()."""
        async with MemoryBackend() as managed:
            await managed.set("key", {"v": 1}, ttl=60.0)
            fetched = await managed.get("key")
            assert fetched is not None

    async def test_len_returns_entry_count(self) -> None:
        """len() follows the entry count through two sets and one delete."""
        store = MemoryBackend()
        try:
            assert len(store) == 0

            await store.set("key1", {"v": 1}, ttl=60.0)
            assert len(store) == 1

            await store.set("key2", {"v": 2}, ttl=60.0)
            assert len(store) == 2

            await store.delete("key1")
            assert len(store) == 1
        finally:
            await store.close()

    async def test_update_existing_key(self) -> None:
        """A second set() on the same key replaces the stored value."""
        store = MemoryBackend()
        try:
            for value in (1, 2):
                await store.set("key", {"v": value}, ttl=60.0)
            fetched = await store.get("key")
            assert fetched is not None
            assert fetched["v"] == 2
        finally:
            await store.close()

    async def test_increment_nonexistent_key(self) -> None:
        """increment() by 5 on a key that was never set returns 5."""
        store = MemoryBackend()
        try:
            total = await store.increment("nonexistent", 5)
            assert total == 5
        finally:
            await store.close()
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSQLiteBackendAdvanced:
    """Concurrency, expiry, statistics and context-manager checks for SQLiteBackend."""

    async def test_concurrent_writes(self) -> None:
        """50 concurrent set() calls each leave their own value readable."""
        db = SQLiteBackend(":memory:")
        await db.initialize()
        try:
            async def store_value(index: int) -> None:
                await db.set(f"key_{index}", {"value": index}, ttl=60.0)

            await asyncio.gather(*(store_value(index) for index in range(50)))

            for index in range(50):
                fetched = await db.get(f"key_{index}")
                assert fetched is not None
                assert fetched["value"] == index
        finally:
            await db.close()

    async def test_persistence_across_operations(self) -> None:
        """A key stays readable after another key is written and deleted."""
        db = SQLiteBackend(":memory:")
        await db.initialize()
        try:
            await db.set("persist_key", {"data": "test"}, ttl=3600.0)

            # Unrelated write + delete must not disturb persist_key.
            await db.set("other_key", {"data": "other"}, ttl=3600.0)
            await db.delete("other_key")

            fetched = await db.get("persist_key")
            assert fetched is not None
            assert fetched["data"] == "test"
        finally:
            await db.close()

    async def test_ttl_expiration(self) -> None:
        """A 0.1s-TTL key exists right away and reads as None 0.15s later."""
        db = SQLiteBackend(":memory:")
        await db.initialize()
        try:
            await db.set("expire_key", {"v": 1}, ttl=0.1)
            present_at_start = await db.exists("expire_key")
            assert present_at_start

            await asyncio.sleep(0.15)

            expired = await db.get("expire_key")
            assert expired is None
        finally:
            await db.close()

    async def test_get_stats_detailed(self) -> None:
        """get_stats() separates one live and one expired entry and reports db_path."""
        db = SQLiteBackend(":memory:")
        await db.initialize()
        try:
            await db.set("key1", {"v": 1}, ttl=60.0)
            await db.set("key2", {"v": 2}, ttl=0.01)
            # Wait past key2's TTL so it counts as expired.
            await asyncio.sleep(0.02)

            report = await db.get_stats()
            expectations = (
                ("total_entries", 2),
                ("active_entries", 1),
                ("expired_entries", 1),
                ("db_path", ":memory:"),
            )
            for field, expected in expectations:
                assert report[field] == expected
        finally:
            await db.close()

    async def test_context_manager(self) -> None:
        """An initialized SQLiteBackend used via `async with` supports set() and get()."""
        db = SQLiteBackend(":memory:")
        await db.initialize()
        async with db:
            await db.set("key", {"v": 1}, ttl=60.0)
            fetched = await db.get("key")
            assert fetched is not None
|
|
|
|
|
|
@pytest.mark.asyncio
class TestBackendInterface:
    """Tests to verify backend interface consistency.

    Each test runs the same sequence of calls against every backend
    produced by the ``backends`` fixture.
    """

    @pytest.fixture
    async def backends(
        self,
    ) -> AsyncGenerator[list[MemoryBackend | SQLiteBackend], None]:
        """Create all backend types for testing.

        Yields:
            A list holding one MemoryBackend and one initialized
            ``":memory:"`` SQLiteBackend, in that order.

        Both backends are closed after the test. Cleanup is guarded so that
        a failure in one step does not leave the other backend open.
        """
        memory = MemoryBackend()
        try:
            sqlite = SQLiteBackend(":memory:")
            await sqlite.initialize()
        except BaseException:
            # Setup failed before the yield, so the teardown below will never
            # run; release the already-created memory backend here.
            await memory.close()
            raise

        yield [memory, sqlite]

        try:
            await memory.close()
        finally:
            # Close SQLite even if closing the memory backend raised.
            await sqlite.close()

    async def test_all_backends_support_basic_operations(
        self, backends: list[MemoryBackend | SQLiteBackend]
    ) -> None:
        """Test that all backends support the same basic operations."""
        for backend in backends:
            await backend.set("test_key", {"count": 1}, ttl=60.0)

            result = await backend.get("test_key")
            assert result is not None
            assert result["count"] == 1

            assert await backend.exists("test_key")

            # Stored count is 1, so incrementing by 5 must report 6 on
            # every backend.
            incremented = await backend.increment("test_key", 5)
            assert incremented == 6

            await backend.delete("test_key")
            assert not await backend.exists("test_key")

    async def test_all_backends_handle_missing_keys(
        self, backends: list[MemoryBackend | SQLiteBackend]
    ) -> None:
        """Test that all backends handle missing keys consistently."""
        for backend in backends:
            result = await backend.get("missing_key")
            assert result is None

            exists = await backend.exists("missing_key")
            assert exists is False

            # Deleting a key that was never set must not raise.
            await backend.delete("missing_key")

    async def test_all_backends_support_clear(
        self, backends: list[MemoryBackend | SQLiteBackend]
    ) -> None:
        """Test that all backends support clear operation."""
        for backend in backends:
            await backend.set("key1", {"v": 1}, ttl=60.0)
            await backend.set("key2", {"v": 2}, ttl=60.0)

            await backend.clear()

            assert not await backend.exists("key1")
            assert not await backend.exists("key2")
|