Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for fastapi-traffic."""
|
||||
211
tests/test_algorithms.py
Normal file
211
tests/test_algorithms.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Tests for rate limiting algorithms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.core.algorithms import (
|
||||
Algorithm,
|
||||
FixedWindowAlgorithm,
|
||||
LeakyBucketAlgorithm,
|
||||
SlidingWindowAlgorithm,
|
||||
SlidingWindowCounterAlgorithm,
|
||||
TokenBucketAlgorithm,
|
||||
get_algorithm,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def backend() -> AsyncGenerator[MemoryBackend, None]:
|
||||
"""Create a memory backend for testing."""
|
||||
backend = MemoryBackend()
|
||||
yield backend
|
||||
await backend.close()
|
||||
|
||||
|
||||
class TestTokenBucketAlgorithm:
|
||||
"""Tests for TokenBucketAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = TokenBucketAlgorithm(10, 60.0, backend)
|
||||
|
||||
for i in range(10):
|
||||
allowed, _ = await algo.check(f"key_{i % 2}")
|
||||
assert allowed, f"Request {i} should be allowed"
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = TokenBucketAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
allowed, info = await algo.check("test_key")
|
||||
assert not allowed
|
||||
assert info.retry_after is not None
|
||||
assert info.retry_after > 0
|
||||
|
||||
async def test_reset(self, backend: MemoryBackend) -> None:
|
||||
"""Test reset functionality."""
|
||||
algo = TokenBucketAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
await algo.check("test_key")
|
||||
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert not allowed
|
||||
|
||||
await algo.reset("test_key")
|
||||
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
|
||||
class TestSlidingWindowAlgorithm:
|
||||
"""Tests for SlidingWindowAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = SlidingWindowAlgorithm(5, 60.0, backend)
|
||||
|
||||
for _ in range(5):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = SlidingWindowAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
allowed, info = await algo.check("test_key")
|
||||
assert not allowed
|
||||
assert info.remaining == 0
|
||||
|
||||
|
||||
class TestFixedWindowAlgorithm:
|
||||
"""Tests for FixedWindowAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = FixedWindowAlgorithm(5, 60.0, backend)
|
||||
|
||||
for _ in range(5):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = FixedWindowAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
allowed, info = await algo.check("test_key")
|
||||
assert not allowed
|
||||
assert info.remaining == 0
|
||||
|
||||
|
||||
class TestLeakyBucketAlgorithm:
|
||||
"""Tests for LeakyBucketAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = LeakyBucketAlgorithm(5, 60.0, backend)
|
||||
|
||||
for _ in range(5):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = LeakyBucketAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert not allowed
|
||||
|
||||
|
||||
class TestSlidingWindowCounterAlgorithm:
|
||||
"""Tests for SlidingWindowCounterAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = SlidingWindowCounterAlgorithm(5, 60.0, backend)
|
||||
|
||||
for _ in range(5):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = SlidingWindowCounterAlgorithm(3, 60.0, backend)
|
||||
|
||||
for _ in range(3):
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert not allowed
|
||||
|
||||
|
||||
class TestGetAlgorithm:
|
||||
"""Tests for get_algorithm factory function."""
|
||||
|
||||
async def test_get_token_bucket(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting token bucket algorithm."""
|
||||
algo = get_algorithm(Algorithm.TOKEN_BUCKET, 10, 60.0, backend)
|
||||
assert isinstance(algo, TokenBucketAlgorithm)
|
||||
|
||||
async def test_get_sliding_window(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting sliding window algorithm."""
|
||||
algo = get_algorithm(Algorithm.SLIDING_WINDOW, 10, 60.0, backend)
|
||||
assert isinstance(algo, SlidingWindowAlgorithm)
|
||||
|
||||
async def test_get_fixed_window(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting fixed window algorithm."""
|
||||
algo = get_algorithm(Algorithm.FIXED_WINDOW, 10, 60.0, backend)
|
||||
assert isinstance(algo, FixedWindowAlgorithm)
|
||||
|
||||
async def test_get_leaky_bucket(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting leaky bucket algorithm."""
|
||||
algo = get_algorithm(Algorithm.LEAKY_BUCKET, 10, 60.0, backend)
|
||||
assert isinstance(algo, LeakyBucketAlgorithm)
|
||||
|
||||
async def test_get_sliding_window_counter(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
"""Test getting sliding window counter algorithm."""
|
||||
algo = get_algorithm(Algorithm.SLIDING_WINDOW_COUNTER, 10, 60.0, backend)
|
||||
assert isinstance(algo, SlidingWindowCounterAlgorithm)
|
||||
143
tests/test_backends.py
Normal file
143
tests/test_backends.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for rate limit backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.backends.sqlite import SQLiteBackend
|
||||
|
||||
|
||||
class TestMemoryBackend:
|
||||
"""Tests for MemoryBackend."""
|
||||
|
||||
@pytest.fixture
|
||||
async def backend(self) -> AsyncGenerator[MemoryBackend, None]:
|
||||
"""Create a memory backend for testing."""
|
||||
backend = MemoryBackend(max_size=100, cleanup_interval=1.0)
|
||||
yield backend
|
||||
await backend.close()
|
||||
|
||||
async def test_set_and_get(self, backend: MemoryBackend) -> None:
|
||||
"""Test basic set and get operations."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
result = await backend.get("test_key")
|
||||
assert result is not None
|
||||
assert result["count"] == 5
|
||||
|
||||
async def test_get_nonexistent(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting a nonexistent key."""
|
||||
result = await backend.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
async def test_delete(self, backend: MemoryBackend) -> None:
|
||||
"""Test delete operation."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
await backend.delete("test_key")
|
||||
result = await backend.get("test_key")
|
||||
assert result is None
|
||||
|
||||
async def test_exists(self, backend: MemoryBackend) -> None:
|
||||
"""Test exists operation."""
|
||||
assert not await backend.exists("test_key")
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
assert await backend.exists("test_key")
|
||||
|
||||
async def test_increment(self, backend: MemoryBackend) -> None:
|
||||
"""Test increment operation."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
result = await backend.increment("test_key", 3)
|
||||
assert result == 8
|
||||
|
||||
async def test_clear(self, backend: MemoryBackend) -> None:
|
||||
"""Test clear operation."""
|
||||
await backend.set("key1", {"count": 1}, ttl=60.0)
|
||||
await backend.set("key2", {"count": 2}, ttl=60.0)
|
||||
await backend.clear()
|
||||
assert not await backend.exists("key1")
|
||||
assert not await backend.exists("key2")
|
||||
|
||||
async def test_ttl_expiration(self, backend: MemoryBackend) -> None:
|
||||
"""Test that entries expire after TTL."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=0.1)
|
||||
await asyncio.sleep(0.2)
|
||||
result = await backend.get("test_key")
|
||||
assert result is None
|
||||
|
||||
async def test_lru_eviction(self) -> None:
|
||||
"""Test LRU eviction when max size is reached."""
|
||||
backend = MemoryBackend(max_size=3)
|
||||
try:
|
||||
await backend.set("key1", {"v": 1}, ttl=60.0)
|
||||
await backend.set("key2", {"v": 2}, ttl=60.0)
|
||||
await backend.set("key3", {"v": 3}, ttl=60.0)
|
||||
await backend.set("key4", {"v": 4}, ttl=60.0)
|
||||
|
||||
assert not await backend.exists("key1")
|
||||
assert await backend.exists("key2")
|
||||
assert await backend.exists("key3")
|
||||
assert await backend.exists("key4")
|
||||
finally:
|
||||
await backend.close()
|
||||
|
||||
|
||||
class TestSQLiteBackend:
|
||||
"""Tests for SQLiteBackend."""
|
||||
|
||||
@pytest.fixture
|
||||
async def backend(self) -> AsyncGenerator[SQLiteBackend, None]:
|
||||
"""Create an in-memory SQLite backend for testing."""
|
||||
backend = SQLiteBackend(":memory:", cleanup_interval=1.0)
|
||||
await backend.initialize()
|
||||
yield backend
|
||||
await backend.close()
|
||||
|
||||
async def test_set_and_get(self, backend: SQLiteBackend) -> None:
|
||||
"""Test basic set and get operations."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
result = await backend.get("test_key")
|
||||
assert result is not None
|
||||
assert result["count"] == 5
|
||||
|
||||
async def test_get_nonexistent(self, backend: SQLiteBackend) -> None:
|
||||
"""Test getting a nonexistent key."""
|
||||
result = await backend.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
async def test_delete(self, backend: SQLiteBackend) -> None:
|
||||
"""Test delete operation."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
await backend.delete("test_key")
|
||||
result = await backend.get("test_key")
|
||||
assert result is None
|
||||
|
||||
async def test_exists(self, backend: SQLiteBackend) -> None:
|
||||
"""Test exists operation."""
|
||||
assert not await backend.exists("test_key")
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
assert await backend.exists("test_key")
|
||||
|
||||
async def test_increment(self, backend: SQLiteBackend) -> None:
|
||||
"""Test increment operation."""
|
||||
await backend.set("test_key", {"count": 5}, ttl=60.0)
|
||||
result = await backend.increment("test_key", 3)
|
||||
assert result == 8
|
||||
|
||||
async def test_clear(self, backend: SQLiteBackend) -> None:
|
||||
"""Test clear operation."""
|
||||
await backend.set("key1", {"count": 1}, ttl=60.0)
|
||||
await backend.set("key2", {"count": 2}, ttl=60.0)
|
||||
await backend.clear()
|
||||
assert not await backend.exists("key1")
|
||||
assert not await backend.exists("key2")
|
||||
|
||||
async def test_get_stats(self, backend: SQLiteBackend) -> None:
|
||||
"""Test get_stats operation."""
|
||||
await backend.set("key1", {"count": 1}, ttl=60.0)
|
||||
await backend.set("key2", {"count": 2}, ttl=60.0)
|
||||
stats = await backend.get_stats()
|
||||
assert stats["total_entries"] == 2
|
||||
assert stats["active_entries"] == 2
|
||||
Reference in New Issue
Block a user