"""Tests for rate limiting algorithms."""

from __future__ import annotations

from typing import AsyncGenerator

import pytest

from fastapi_traffic.backends.memory import MemoryBackend
from fastapi_traffic.core.algorithms import (
    Algorithm,
    FixedWindowAlgorithm,
    LeakyBucketAlgorithm,
    SlidingWindowAlgorithm,
    SlidingWindowCounterAlgorithm,
    TokenBucketAlgorithm,
    get_algorithm,
)

# NOTE(review): every test below is a coroutine with no explicit marker;
# presumably the project's pytest configuration enables an async plugin in
# "auto" mode -- confirm before running this module in another setup.


@pytest.fixture
async def backend() -> AsyncGenerator[MemoryBackend, None]:
    """Yield a fresh ``MemoryBackend`` and close it after the test body ran."""
    store = MemoryBackend()
    yield store
    await store.close()


class TestTokenBucketAlgorithm:
    """Behavioural checks for ``TokenBucketAlgorithm``."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Ten checks alternating between two keys must all be allowed."""
        limiter = TokenBucketAlgorithm(10, 60.0, backend)

        for i in range(10):
            # ``i % 2`` alternates the key between "key_0" and "key_1".
            permitted, _ = await limiter.check(f"key_{i % 2}")
            assert permitted, f"Request {i} should be allowed"

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """A fourth check against a limit of 3 is refused with a retry hint."""
        limiter = TokenBucketAlgorithm(3, 60.0, backend)

        for _attempt in range(3):
            permitted, _ = await limiter.check("test_key")
            assert permitted

        permitted, details = await limiter.check("test_key")
        assert not permitted
        # The refusal must carry a strictly positive ``retry_after`` value.
        assert details.retry_after is not None
        assert details.retry_after > 0

    async def test_reset(self, backend: MemoryBackend) -> None:
        """After ``reset`` a previously refused key is allowed again."""
        limiter = TokenBucketAlgorithm(3, 60.0, backend)

        # Use up the three permitted checks; results are not inspected here.
        for _attempt in range(3):
            await limiter.check("test_key")

        permitted, _ = await limiter.check("test_key")
        assert not permitted

        await limiter.reset("test_key")

        permitted, _ = await limiter.check("test_key")
        assert permitted


class TestSlidingWindowAlgorithm:
    """Behavioural checks for ``SlidingWindowAlgorithm``."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Five checks on one key against a limit of 5 are all allowed."""
        limiter = SlidingWindowAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            permitted, _ = await limiter.check("test_key")
            assert permitted

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """A fourth check against a limit of 3 is refused with 0 remaining."""
        limiter = SlidingWindowAlgorithm(3, 60.0, backend)

        for _attempt in range(3):
            permitted, _ = await limiter.check("test_key")
            assert permitted

        permitted, details = await limiter.check("test_key")
        assert not permitted
        assert details.remaining == 0


class TestFixedWindowAlgorithm:
    """Behavioural checks for ``FixedWindowAlgorithm``."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Five checks on one key against a limit of 5 are all allowed."""
        limiter = FixedWindowAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            permitted, _ = await limiter.check("test_key")
            assert permitted

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """A fourth check against a limit of 3 is refused with 0 remaining."""
        limiter = FixedWindowAlgorithm(3, 60.0, backend)

        for _attempt in range(3):
            permitted, _ = await limiter.check("test_key")
            assert permitted

        permitted, details = await limiter.check("test_key")
        assert not permitted
        assert details.remaining == 0


class TestLeakyBucketAlgorithm:
    """Behavioural checks for ``LeakyBucketAlgorithm``."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Five checks on one key against a limit of 5 are all allowed."""
        limiter = LeakyBucketAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            permitted, _ = await limiter.check("test_key")
            assert permitted

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """A fourth check against a limit of 3 is refused."""
        limiter = LeakyBucketAlgorithm(3, 60.0, backend)

        for _attempt in range(3):
            permitted, _ = await limiter.check("test_key")
            assert permitted

        permitted, _ = await limiter.check("test_key")
        assert not permitted


class TestSlidingWindowCounterAlgorithm:
    """Behavioural checks for ``SlidingWindowCounterAlgorithm``."""

    async def test_allows_requests_within_limit(
        self, backend: MemoryBackend
    ) -> None:
        """Five checks on one key against a limit of 5 are all allowed."""
        limiter = SlidingWindowCounterAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            permitted, _ = await limiter.check("test_key")
            assert permitted

    async def test_blocks_requests_over_limit(
        self, backend: MemoryBackend
    ) -> None:
        """A fourth check against a limit of 3 is refused."""
        limiter = SlidingWindowCounterAlgorithm(3, 60.0, backend)

        for _attempt in range(3):
            permitted, _ = await limiter.check("test_key")
            assert permitted

        permitted, _ = await limiter.check("test_key")
        assert not permitted


class TestGetAlgorithm:
    """Checks that ``get_algorithm`` returns the class matching each enum."""

    async def test_get_token_bucket(self, backend: MemoryBackend) -> None:
        """``Algorithm.TOKEN_BUCKET`` yields a ``TokenBucketAlgorithm``."""
        assert isinstance(
            get_algorithm(Algorithm.TOKEN_BUCKET, 10, 60.0, backend),
            TokenBucketAlgorithm,
        )

    async def test_get_sliding_window(self, backend: MemoryBackend) -> None:
        """``Algorithm.SLIDING_WINDOW`` yields a ``SlidingWindowAlgorithm``."""
        assert isinstance(
            get_algorithm(Algorithm.SLIDING_WINDOW, 10, 60.0, backend),
            SlidingWindowAlgorithm,
        )

    async def test_get_fixed_window(self, backend: MemoryBackend) -> None:
        """``Algorithm.FIXED_WINDOW`` yields a ``FixedWindowAlgorithm``."""
        assert isinstance(
            get_algorithm(Algorithm.FIXED_WINDOW, 10, 60.0, backend),
            FixedWindowAlgorithm,
        )

    async def test_get_leaky_bucket(self, backend: MemoryBackend) -> None:
        """``Algorithm.LEAKY_BUCKET`` yields a ``LeakyBucketAlgorithm``."""
        assert isinstance(
            get_algorithm(Algorithm.LEAKY_BUCKET, 10, 60.0, backend),
            LeakyBucketAlgorithm,
        )

    async def test_get_sliding_window_counter(
        self, backend: MemoryBackend
    ) -> None:
        """``SLIDING_WINDOW_COUNTER`` yields the counter implementation."""
        assert isinstance(
            get_algorithm(Algorithm.SLIDING_WINDOW_COUNTER, 10, 60.0, backend),
            SlidingWindowCounterAlgorithm,
        )