- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.)
- SQLite and in-memory backends
- Decorator and dependency-injection patterns
- Middleware support
- Example usage files
212 lines
6.5 KiB
Python
212 lines
6.5 KiB
Python
"""Tests for rate limiting algorithms."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
import pytest
|
|
|
|
from fastapi_traffic.backends.memory import MemoryBackend
|
|
from fastapi_traffic.core.algorithms import (
|
|
Algorithm,
|
|
FixedWindowAlgorithm,
|
|
LeakyBucketAlgorithm,
|
|
SlidingWindowAlgorithm,
|
|
SlidingWindowCounterAlgorithm,
|
|
TokenBucketAlgorithm,
|
|
get_algorithm,
|
|
)
|
|
|
|
|
|
@pytest.fixture
async def backend() -> AsyncGenerator[MemoryBackend, None]:
    """Yield a fresh in-memory backend and close it once the test finishes.

    The fixture has no explicit scope, so every test gets its own backend
    instance and state never leaks between tests.
    """
    mem_backend = MemoryBackend()
    yield mem_backend
    # Teardown: runs after the requesting test completes.
    await mem_backend.close()
|
|
|
|
|
|
class TestTokenBucketAlgorithm:
    """Tests for TokenBucketAlgorithm."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Test that a single key can spend its whole bucket (10 requests)."""
        algo = TokenBucketAlgorithm(10, 60.0, backend)

        # Send every request on one key. Alternating between two keys would
        # only consume half of each bucket and never approach the limit.
        for i in range(10):
            allowed, _ = await algo.check("test_key")
            assert allowed, f"Request {i} should be allowed"

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Test that requests over limit are blocked with a positive retry_after."""
        algo = TokenBucketAlgorithm(3, 60.0, backend)

        for _ in range(3):
            allowed, _ = await algo.check("test_key")
            assert allowed

        allowed, info = await algo.check("test_key")
        assert not allowed
        assert info.retry_after is not None
        assert info.retry_after > 0

    async def test_reset(self, backend: MemoryBackend) -> None:
        """Test that reset() restores capacity for an exhausted key."""
        algo = TokenBucketAlgorithm(3, 60.0, backend)

        # Exhaust the bucket. Assert each warm-up request is admitted so a
        # failure here is not misreported as a reset problem.
        for _ in range(3):
            allowed, _ = await algo.check("test_key")
            assert allowed

        allowed, _ = await algo.check("test_key")
        assert not allowed

        await algo.reset("test_key")

        allowed, _ = await algo.check("test_key")
        assert allowed
|
|
|
|
|
|
class TestSlidingWindowAlgorithm:
    """Behavioural checks for SlidingWindowAlgorithm on a memory backend."""

    @staticmethod
    async def _admit(limiter: SlidingWindowAlgorithm, count: int) -> None:
        """Issue ``count`` checks on "test_key" and require each one to pass."""
        for _attempt in range(count):
            admitted, _info = await limiter.check("test_key")
            assert admitted

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Every request up to the configured limit (5) is admitted."""
        limiter = SlidingWindowAlgorithm(5, 60.0, backend)
        await self._admit(limiter, 5)

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """The request after the limit (3) is rejected with nothing remaining."""
        limiter = SlidingWindowAlgorithm(3, 60.0, backend)
        await self._admit(limiter, 3)

        admitted, info = await limiter.check("test_key")
        assert not admitted
        assert info.remaining == 0
|
|
|
|
|
|
class TestFixedWindowAlgorithm:
    """Behavioural checks for FixedWindowAlgorithm on a memory backend."""

    @staticmethod
    async def _admit(limiter: FixedWindowAlgorithm, count: int) -> None:
        """Issue ``count`` checks on "test_key" and require each one to pass."""
        for _attempt in range(count):
            admitted, _info = await limiter.check("test_key")
            assert admitted

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Every request up to the configured limit (5) is admitted."""
        limiter = FixedWindowAlgorithm(5, 60.0, backend)
        await self._admit(limiter, 5)

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """The request after the limit (3) is rejected with nothing remaining."""
        limiter = FixedWindowAlgorithm(3, 60.0, backend)
        await self._admit(limiter, 3)

        admitted, info = await limiter.check("test_key")
        assert not admitted
        assert info.remaining == 0
|
|
|
|
|
|
class TestLeakyBucketAlgorithm:
    """Behavioural checks for LeakyBucketAlgorithm on a memory backend."""

    @staticmethod
    async def _admit(limiter: LeakyBucketAlgorithm, count: int) -> None:
        """Issue ``count`` checks on "test_key" and require each one to pass."""
        for _attempt in range(count):
            admitted, _info = await limiter.check("test_key")
            assert admitted

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Every request up to the configured limit (5) is admitted."""
        limiter = LeakyBucketAlgorithm(5, 60.0, backend)
        await self._admit(limiter, 5)

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """The request after the limit (3) is rejected."""
        limiter = LeakyBucketAlgorithm(3, 60.0, backend)
        await self._admit(limiter, 3)

        admitted, _info = await limiter.check("test_key")
        assert not admitted
|
|
|
|
|
|
class TestSlidingWindowCounterAlgorithm:
    """Behavioural checks for SlidingWindowCounterAlgorithm on a memory backend."""

    @staticmethod
    async def _admit(
        limiter: SlidingWindowCounterAlgorithm, count: int
    ) -> None:
        """Issue ``count`` checks on "test_key" and require each one to pass."""
        for _attempt in range(count):
            admitted, _info = await limiter.check("test_key")
            assert admitted

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Every request up to the configured limit (5) is admitted."""
        limiter = SlidingWindowCounterAlgorithm(5, 60.0, backend)
        await self._admit(limiter, 5)

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """The request after the limit (3) is rejected."""
        limiter = SlidingWindowCounterAlgorithm(3, 60.0, backend)
        await self._admit(limiter, 3)

        admitted, _info = await limiter.check("test_key")
        assert not admitted
|
|
|
|
|
|
class TestGetAlgorithm:
    """Check that get_algorithm maps each Algorithm member to its class."""

    @staticmethod
    def _build(algorithm: Algorithm, backend: MemoryBackend) -> object:
        """Resolve ``algorithm`` via the factory (limit 10, window 60.0)."""
        return get_algorithm(algorithm, 10, 60.0, backend)

    async def test_get_token_bucket(self, backend: MemoryBackend) -> None:
        """TOKEN_BUCKET resolves to TokenBucketAlgorithm."""
        built = self._build(Algorithm.TOKEN_BUCKET, backend)
        assert isinstance(built, TokenBucketAlgorithm)

    async def test_get_sliding_window(self, backend: MemoryBackend) -> None:
        """SLIDING_WINDOW resolves to SlidingWindowAlgorithm."""
        built = self._build(Algorithm.SLIDING_WINDOW, backend)
        assert isinstance(built, SlidingWindowAlgorithm)

    async def test_get_fixed_window(self, backend: MemoryBackend) -> None:
        """FIXED_WINDOW resolves to FixedWindowAlgorithm."""
        built = self._build(Algorithm.FIXED_WINDOW, backend)
        assert isinstance(built, FixedWindowAlgorithm)

    async def test_get_leaky_bucket(self, backend: MemoryBackend) -> None:
        """LEAKY_BUCKET resolves to LeakyBucketAlgorithm."""
        built = self._build(Algorithm.LEAKY_BUCKET, backend)
        assert isinstance(built, LeakyBucketAlgorithm)

    async def test_get_sliding_window_counter(
        self, backend: MemoryBackend
    ) -> None:
        """SLIDING_WINDOW_COUNTER resolves to SlidingWindowCounterAlgorithm."""
        built = self._build(Algorithm.SLIDING_WINDOW_COUNTER, backend)
        assert isinstance(built, SlidingWindowCounterAlgorithm)
|