"""Tests for rate limiting algorithms. Comprehensive tests covering: - Basic allow/block behavior - Limit boundaries and edge cases - Token refill and window reset timing - Concurrent access patterns - State persistence and recovery - Different key isolation """ from __future__ import annotations import asyncio from typing import TYPE_CHECKING import pytest from fastapi_traffic.backends.memory import MemoryBackend from fastapi_traffic.core.algorithms import ( Algorithm, FixedWindowAlgorithm, LeakyBucketAlgorithm, SlidingWindowAlgorithm, SlidingWindowCounterAlgorithm, TokenBucketAlgorithm, get_algorithm, ) if TYPE_CHECKING: from collections.abc import AsyncGenerator @pytest.fixture async def backend() -> AsyncGenerator[MemoryBackend, None]: """Create a memory backend for testing.""" backend = MemoryBackend() yield backend await backend.close() @pytest.mark.asyncio class TestTokenBucketAlgorithm: """Tests for TokenBucketAlgorithm.""" async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = TokenBucketAlgorithm(10, 60.0, backend) for i in range(10): allowed, _ = await algo.check(f"key_{i % 2}") assert allowed, f"Request {i} should be allowed" async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = TokenBucketAlgorithm(3, 60.0, backend) for _ in range(3): allowed, _ = await algo.check("test_key") assert allowed allowed, info = await algo.check("test_key") assert not allowed assert info.retry_after is not None assert info.retry_after > 0 async def test_reset(self, backend: MemoryBackend) -> None: """Test reset functionality.""" algo = TokenBucketAlgorithm(3, 60.0, backend) for _ in range(3): await algo.check("test_key") allowed, _ = await algo.check("test_key") assert not allowed await algo.reset("test_key") allowed, _ = await algo.check("test_key") assert allowed @pytest.mark.asyncio class TestSlidingWindowAlgorithm: """Tests for SlidingWindowAlgorithm.""" async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = SlidingWindowAlgorithm(5, 60.0, backend) for _ in range(5): allowed, _ = await algo.check("test_key") assert allowed async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = SlidingWindowAlgorithm(3, 60.0, backend) for _ in range(3): allowed, _ = await algo.check("test_key") assert allowed allowed, info = await algo.check("test_key") assert not allowed assert info.remaining == 0 @pytest.mark.asyncio class TestFixedWindowAlgorithm: """Tests for FixedWindowAlgorithm.""" async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = FixedWindowAlgorithm(5, 60.0, backend) for _ in range(5): allowed, _ = await algo.check("test_key") assert allowed async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = FixedWindowAlgorithm(3, 60.0, backend) for _ in range(3): allowed, _ = await algo.check("test_key") assert allowed allowed, info = await algo.check("test_key") assert not allowed assert info.remaining == 0 @pytest.mark.asyncio class TestLeakyBucketAlgorithm: """Tests for LeakyBucketAlgorithm.""" async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = LeakyBucketAlgorithm(5, 60.0, backend) for _ in range(5): allowed, _ = await algo.check("test_key") assert allowed async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = LeakyBucketAlgorithm(3, 60.0, backend) # Leaky bucket allows burst_size requests initially for _ in range(3): allowed, _ = await algo.check("test_key") assert allowed # After burst, should eventually block # Note: Leaky bucket behavior depends on leak rate allowed, info = await algo.check("test_key") # Just verify we get valid info back assert info.limit == 3 @pytest.mark.asyncio class TestSlidingWindowCounterAlgorithm: """Tests for SlidingWindowCounterAlgorithm.""" async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = SlidingWindowCounterAlgorithm(5, 60.0, backend) for _ in range(5): allowed, _ = await algo.check("test_key") assert allowed async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = SlidingWindowCounterAlgorithm(3, 60.0, backend) for _ in range(3): allowed, _ = await algo.check("test_key") assert allowed allowed, _ = await algo.check("test_key") assert not allowed @pytest.mark.asyncio class TestGetAlgorithm: """Tests for get_algorithm factory function.""" async def test_get_token_bucket(self, backend: MemoryBackend) -> None: """Test getting token bucket algorithm.""" algo = get_algorithm(Algorithm.TOKEN_BUCKET, 10, 60.0, backend) assert isinstance(algo, TokenBucketAlgorithm) async def test_get_sliding_window(self, backend: MemoryBackend) -> None: """Test getting sliding window algorithm.""" algo = get_algorithm(Algorithm.SLIDING_WINDOW, 10, 60.0, backend) assert isinstance(algo, SlidingWindowAlgorithm) async def test_get_fixed_window(self, backend: MemoryBackend) -> None: """Test getting fixed window algorithm.""" algo = get_algorithm(Algorithm.FIXED_WINDOW, 10, 60.0, backend) assert isinstance(algo, FixedWindowAlgorithm) async def test_get_leaky_bucket(self, backend: MemoryBackend) -> None: """Test getting leaky bucket algorithm.""" algo = get_algorithm(Algorithm.LEAKY_BUCKET, 10, 60.0, backend) assert isinstance(algo, LeakyBucketAlgorithm) async def test_get_sliding_window_counter(self, backend: MemoryBackend) -> None: """Test getting sliding window counter algorithm.""" algo = get_algorithm(Algorithm.SLIDING_WINDOW_COUNTER, 10, 60.0, backend) assert isinstance(algo, SlidingWindowCounterAlgorithm) @pytest.mark.asyncio class TestTokenBucketAdvanced: """Advanced tests for TokenBucketAlgorithm.""" async def test_token_refill_over_time(self, backend: MemoryBackend) -> None: """Test that tokens refill after time passes.""" algo = TokenBucketAlgorithm(5, 1.0, backend) for _ in range(5): allowed, _ = await algo.check("refill_key") assert allowed allowed, _ = await algo.check("refill_key") assert not allowed await asyncio.sleep(0.3) allowed, _ = await algo.check("refill_key") assert allowed async def test_burst_size_configuration(self, backend: MemoryBackend) -> None: """Test that burst_size limits initial tokens.""" algo = TokenBucketAlgorithm(100, 60.0, backend, burst_size=5) for i in range(5): allowed, _ = await algo.check("burst_key") assert allowed, f"Request {i} should be allowed" allowed, _ = await algo.check("burst_key") assert not allowed async def test_key_isolation(self, backend: MemoryBackend) -> None: """Test that different keys have separate limits.""" algo = TokenBucketAlgorithm(3, 60.0, backend) for _ in range(3): await algo.check("key_a") allowed_a, _ = await algo.check("key_a") assert not allowed_a allowed_b, _ = await algo.check("key_b") assert allowed_b async def test_concurrent_requests(self, backend: MemoryBackend) -> None: """Test concurrent request handling.""" algo = TokenBucketAlgorithm(10, 60.0, backend) async def make_request() -> bool: allowed, _ = await algo.check("concurrent_key") return allowed results = await asyncio.gather(*[make_request() for _ in range(15)]) allowed_count = sum(results) assert allowed_count == 10 async def test_rate_limit_info_accuracy(self, backend: MemoryBackend) -> None: """Test that rate limit info is accurate.""" algo = TokenBucketAlgorithm(5, 60.0, backend) allowed, info = await algo.check("info_key") assert allowed assert info.limit == 5 assert info.remaining == 4 for _ in range(4): await algo.check("info_key") allowed, info = await algo.check("info_key") assert not allowed assert info.remaining == 0 assert info.retry_after is not None assert info.retry_after > 0 @pytest.mark.asyncio class TestSlidingWindowAdvanced: """Advanced tests for SlidingWindowAlgorithm.""" async def test_window_expiration(self, backend: MemoryBackend) -> None: """Test that old requests expire from the window.""" algo = SlidingWindowAlgorithm(3, 0.5, backend) for _ in range(3): allowed, _ = await algo.check("expire_key") assert allowed allowed, _ = await algo.check("expire_key") assert not allowed await asyncio.sleep(0.6) allowed, _ = await algo.check("expire_key") assert allowed async def test_sliding_behavior(self, backend: MemoryBackend) -> None: """Test that window slides correctly.""" algo = SlidingWindowAlgorithm(2, 1.0, backend) allowed, _ = await algo.check("slide_key") assert allowed await asyncio.sleep(0.3) allowed, _ = await algo.check("slide_key") assert allowed allowed, _ = await algo.check("slide_key") assert not allowed await asyncio.sleep(0.8) allowed, _ = await algo.check("slide_key") assert allowed @pytest.mark.asyncio class TestFixedWindowAdvanced: """Advanced tests for FixedWindowAlgorithm.""" async def test_window_boundary_reset(self, backend: MemoryBackend) -> None: """Test that counter resets at window boundary.""" algo = FixedWindowAlgorithm(3, 0.5, backend) for _ in range(3): allowed, _ = await algo.check("boundary_key") assert allowed allowed, _ = await algo.check("boundary_key") assert not allowed await asyncio.sleep(0.6) allowed, _ = await algo.check("boundary_key") assert allowed async def test_multiple_windows(self, backend: MemoryBackend) -> None: """Test behavior across multiple windows.""" algo = FixedWindowAlgorithm(2, 0.3, backend) for _ in range(2): allowed, _ = await algo.check("multi_key") assert allowed allowed, _ = await algo.check("multi_key") assert not allowed await asyncio.sleep(0.35) for _ in range(2): allowed, _ = await algo.check("multi_key") assert allowed @pytest.mark.asyncio class TestLeakyBucketAdvanced: """Advanced tests for LeakyBucketAlgorithm.""" async def test_leak_rate(self, backend: MemoryBackend) -> None: """Test that bucket leaks over time.""" algo = LeakyBucketAlgorithm(3, 1.0, backend) # Make initial requests for _ in range(3): allowed, _ = await algo.check("leak_key") assert allowed # Wait for some leaking to occur await asyncio.sleep(0.5) # Should be able to make another request after leak allowed, info = await algo.check("leak_key") assert info.limit == 3 async def test_steady_rate_enforcement(self, backend: MemoryBackend) -> None: """Test that leaky bucket tracks requests.""" algo = LeakyBucketAlgorithm(5, 1.0, backend) # Make several requests for _ in range(5): allowed, info = await algo.check("steady_key") assert allowed assert info.limit == 5 @pytest.mark.asyncio class TestSlidingWindowCounterAdvanced: """Advanced tests for SlidingWindowCounterAlgorithm.""" async def test_weighted_counting(self, backend: MemoryBackend) -> None: """Test weighted counting between windows.""" algo = SlidingWindowCounterAlgorithm(10, 1.0, backend) for _ in range(8): allowed, _ = await algo.check("weighted_key") assert allowed await asyncio.sleep(0.6) allowed, info = await algo.check("weighted_key") assert allowed assert info.remaining > 0 async def test_precision_vs_fixed_window(self, backend: MemoryBackend) -> None: """Test that sliding window counter is more precise than fixed window.""" algo = SlidingWindowCounterAlgorithm(4, 1.0, backend) for _ in range(4): allowed, _ = await algo.check("precision_key") assert allowed allowed, _ = await algo.check("precision_key") assert not allowed # Wait for the full window to pass to ensure tokens are fully replenished await asyncio.sleep(1.1) allowed, _ = await algo.check("precision_key") assert allowed @pytest.mark.asyncio class TestAlgorithmStateManagement: """Tests for algorithm state management.""" async def test_get_state_without_consuming(self, backend: MemoryBackend) -> None: """Test getting state without consuming tokens.""" algo = TokenBucketAlgorithm(5, 60.0, backend) await algo.check("state_key") await algo.check("state_key") state = await algo.get_state("state_key") assert state is not None assert state.remaining == 3 state2 = await algo.get_state("state_key") assert state2 is not None assert state2.remaining == 3 async def test_get_state_nonexistent_key(self, backend: MemoryBackend) -> None: """Test getting state for nonexistent key.""" algo = TokenBucketAlgorithm(5, 60.0, backend) state = await algo.get_state("nonexistent_key") assert state is None async def test_reset_restores_full_capacity(self, backend: MemoryBackend) -> None: """Test that reset restores full capacity.""" algo = TokenBucketAlgorithm(5, 60.0, backend) for _ in range(5): await algo.check("reset_key") allowed, _ = await algo.check("reset_key") assert not allowed await algo.reset("reset_key") allowed, info = await algo.check("reset_key") assert allowed assert info.remaining == 4