"""Tests for rate limit decorator and dependency injection.

Comprehensive tests covering:
- Basic decorator functionality
- Custom key extractors
- Different algorithms via decorator
- Cost parameter
- Exemption callbacks
- On-blocked callbacks
- RateLimitDependency usage
- Header inclusion
- Error handling
"""

from __future__ import annotations

from typing import AsyncGenerator

import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient

from fastapi_traffic import (
    MemoryBackend,
    RateLimitExceeded,
    RateLimiter,
    rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter


@pytest.fixture
async def setup_limiter() -> AsyncGenerator[RateLimiter, None]:
    """Install a fresh rate limiter backed by a new ``MemoryBackend``.

    Every test gets its own backend instance (presumably so no rate-limit
    state is shared between tests), registered globally via ``set_limiter``
    and closed on teardown.
    """
    backend = MemoryBackend()
    limiter = RateLimiter(backend)
    await limiter.initialize()
    set_limiter(limiter)
    yield limiter
    # NOTE(review): the closed limiter remains the registered global until the
    # next test calls set_limiter again — confirm nothing relies on it after
    # teardown.
    await limiter.close()


@pytest.fixture
async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTPX client that talks to ``app`` in-process over ASGI.

    ``app`` is resolved from the requesting test class, so each class only
    has to define its own ``app`` fixture.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client


def _install_rate_limit_handler(app: FastAPI) -> None:
    """Register a handler that converts ``RateLimitExceeded`` into HTTP 429.

    The JSON body carries the exception's ``message`` (as ``detail``) and
    ``retry_after``; when the exception has ``limit_info`` its headers are
    copied onto the response.
    """

    @app.exception_handler(RateLimitExceeded)
    async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
        return JSONResponse(
            status_code=429,
            content={"detail": exc.message, "retry_after": exc.retry_after},
            headers=exc.limit_info.to_headers() if exc.limit_info else {},
        )


class TestRateLimitDecorator:
    """Tests for the @rate_limit decorator."""

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create a test app with rate limited endpoints."""
        app = FastAPI()
        _install_rate_limit_handler(app)

        @app.get("/basic")
        @rate_limit(3, window_size=60)
        async def basic_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/no-headers")
        @rate_limit(3, window_size=60, include_headers=False)
        async def no_headers_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/custom-message")
        @rate_limit(2, window_size=60, error_message="Custom rate limit message")
        async def custom_message_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """Test that requests within limit are allowed."""
        for i in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200, f"Request {i} should succeed"

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """Test that requests over limit are blocked."""
        for _ in range(3):
            await client.get("/basic")

        response = await client.get("/basic")
        assert response.status_code == 429

    async def test_rate_limit_enforced(self, client: AsyncClient) -> None:
        """Test that a blocked request returns the 429 JSON error body."""
        # Use up the limit
        for _ in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200

        # Next request should be blocked, with the handler's error payload
        response = await client.get("/basic")
        assert response.status_code == 429
        data = response.json()
        assert "detail" in data
        assert "retry_after" in data

    async def test_headers_excluded_when_disabled(self, client: AsyncClient) -> None:
        """Test that headers are excluded when include_headers=False."""
        response = await client.get("/no-headers")
        assert response.status_code == 200
        assert "X-RateLimit-Limit" not in response.headers

    async def test_custom_error_message(self, client: AsyncClient) -> None:
        """Test custom error message is used."""
        for _ in range(2):
            await client.get("/custom-message")

        response = await client.get("/custom-message")
        assert response.status_code == 429
        data = response.json()
        assert data["detail"] == "Custom rate limit message"

    async def test_retry_after_header_on_limit(self, client: AsyncClient) -> None:
        """Test Retry-After header is set when rate limited."""
        for _ in range(3):
            await client.get("/basic")

        response = await client.get("/basic")
        assert response.status_code == 429
        assert "Retry-After" in response.headers


class TestCustomKeyExtractor:
    """Tests for custom key extraction."""

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create app with custom key extractor."""
        app = FastAPI()
        _install_rate_limit_handler(app)

        def api_key_extractor(request: Request) -> str:
            # Requests without an API key all share the "anonymous" bucket.
            return request.headers.get("X-API-Key", "anonymous")

        @app.get("/by-api-key")
        @rate_limit(2, window_size=60, key_extractor=api_key_extractor)
        async def by_api_key(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    async def test_different_keys_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different API keys have separate rate limits."""
        for _ in range(2):
            response = await client.get(
                "/by-api-key", headers={"X-API-Key": "key-a"}
            )
            assert response.status_code == 200

        response = await client.get("/by-api-key", headers={"X-API-Key": "key-a"})
        assert response.status_code == 429

        response = await client.get("/by-api-key", headers={"X-API-Key": "key-b"})
        assert response.status_code == 200

    async def test_anonymous_key_for_missing_header(
        self, client: AsyncClient
    ) -> None:
        """Test that missing API key uses anonymous."""
        for _ in range(2):
            response = await client.get("/by-api-key")
            assert response.status_code == 200

        response = await client.get("/by-api-key")
        assert response.status_code == 429


class TestExemptionCallback:
    """Tests for exempt_when callback."""

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create app with exemption callback."""
        app = FastAPI()
        _install_rate_limit_handler(app)

        def is_admin(request: Request) -> bool:
            return request.headers.get("X-Admin-Token") == "secret"

        @app.get("/with-exemption")
        @rate_limit(2, window_size=60, exempt_when=is_admin)
        async def with_exemption(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    async def test_exempt_requests_bypass_limit(self, client: AsyncClient) -> None:
        """Test that exempt requests bypass rate limiting."""
        for _ in range(5):
            response = await client.get(
                "/with-exemption", headers={"X-Admin-Token": "secret"}
            )
            assert response.status_code == 200

    async def test_non_exempt_requests_are_limited(self, client: AsyncClient) -> None:
        """Test that non-exempt requests are rate limited."""
        for _ in range(2):
            response = await client.get("/with-exemption")
            assert response.status_code == 200

        response = await client.get("/with-exemption")
        assert response.status_code == 429


class TestCostParameter:
    """Tests for the cost parameter."""

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create app with cost-based endpoints."""
        app = FastAPI()
        _install_rate_limit_handler(app)

        @app.get("/low-cost")
        @rate_limit(10, window_size=60, cost=1)
        async def low_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/high-cost")
        @rate_limit(10, window_size=60, cost=5)
        async def high_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    async def test_low_cost_allows_more_requests(self, client: AsyncClient) -> None:
        """Test that low cost endpoints allow more requests."""
        for i in range(10):
            response = await client.get("/low-cost")
            assert response.status_code == 200, f"Request {i} should succeed"

        response = await client.get("/low-cost")
        assert response.status_code == 429

    async def test_high_cost_allows_fewer_requests(self, client: AsyncClient) -> None:
        """Test that high cost endpoints allow fewer requests."""
        # A limit of 10 with cost=5 admits exactly two requests.
        for i in range(2):
            response = await client.get("/high-cost")
            assert response.status_code == 200, f"Request {i} should succeed"

        response = await client.get("/high-cost")
        assert response.status_code == 429