"""Tests for rate limiting middleware.

Comprehensive tests covering:
- Basic middleware functionality
- Path exemptions
- IP exemptions
- Custom key extractors
- Different algorithms
- Error handling and skip_on_error
- Header inclusion
- Multiple middleware configurations
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

from fastapi_traffic import MemoryBackend
from fastapi_traffic.middleware import (
    RateLimitMiddleware,
    SlidingWindowMiddleware,
    TokenBucketMiddleware,
)

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator


class TestRateLimitMiddleware:
    """Exercise RateLimitMiddleware with a limit of 5 requests per 60s window."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Build an app with one GET and one POST route behind the limiter."""
        api = FastAPI()
        # A fresh in-memory backend per fixture call keeps counters from
        # leaking between tests.
        api.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=MemoryBackend(),
        )

        @api.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        @api.post("/api/create")
        async def create() -> dict[str, str]:
            return {"status": "created"}

        return api

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an httpx client wired to the app through ASGITransport."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """The first five requests must each return 200."""
        for i in range(5):
            reply = await client.get("/api/resource")
            assert reply.status_code == 200, f"Request {i} should succeed"

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """The sixth request to the same route must return 429."""
        for _ in range(5):
            await client.get("/api/resource")

        rejected = await client.get("/api/resource")
        assert rejected.status_code == 429

    async def test_rate_limit_headers_included(self, client: AsyncClient) -> None:
        """A response carries the three X-RateLimit-* headers."""
        reply = await client.get("/api/resource")
        headers = reply.headers
        assert "X-RateLimit-Limit" in headers
        assert "X-RateLimit-Remaining" in headers
        assert "X-RateLimit-Reset" in headers

    async def test_different_endpoints_counted_separately(
        self, client: AsyncClient
    ) -> None:
        """Send three GETs and two POSTs to two routes; all must return 200."""
        # The original author notes that the middleware includes the path in
        # the key by default.
        # NOTE(review): 3 + 2 requests equals the configured limit of 5, so
        # this test would also pass if both routes shared a single counter.
        for _ in range(3):
            reply = await client.get("/api/resource")
            assert reply.status_code == 200

        for _ in range(2):
            reply = await client.post("/api/create")
            assert reply.status_code == 200

    async def test_rate_limit_response_format(self, client: AsyncClient) -> None:
        """The 429 JSON body exposes ``detail`` and ``retry_after`` keys."""
        for _ in range(5):
            await client.get("/api/resource")

        rejected = await client.get("/api/resource")
        assert rejected.status_code == 429

        payload = rejected.json()
        assert "detail" in payload
        assert "retry_after" in payload


class TestMiddlewareExemptions:
    """Exercise the ``exempt_paths`` option with a limit of 2 per window."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Build an app with two exempt routes and one limited route."""
        api = FastAPI()
        # NOTE(review): exempt_ips is configured here, but no test visible in
        # this file sends a request from 10.0.0.1.
        api.add_middleware(
            RateLimitMiddleware,
            limit=2,
            window_size=60,
            backend=MemoryBackend(),
            exempt_paths={"/health", "/metrics"},
            exempt_ips={"10.0.0.1"},
        )

        @api.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @api.get("/metrics")
        async def metrics() -> dict[str, str]:
            return {"requests": "100"}

        @api.get("/api/data")
        async def data() -> dict[str, str]:
            return {"data": "value"}

        return api

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an httpx client wired to the app through ASGITransport."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_exempt_paths_bypass_limit(self, client: AsyncClient) -> None:
        """Ten requests to an exempt path all return 200 despite limit=2."""
        for _ in range(10):
            reply = await client.get("/health")
            assert reply.status_code == 200

    async def test_multiple_exempt_paths(self, client: AsyncClient) -> None:
        """Both configured exempt paths accept five requests each."""
        for _ in range(5):
            reply = await client.get("/health")
            assert reply.status_code == 200

        for _ in range(5):
            reply = await client.get("/metrics")
            assert reply.status_code == 200

    async def test_non_exempt_paths_are_limited(self, client: AsyncClient) -> None:
        """A non-exempt path returns 429 on the third request."""
        for _ in range(2):
            reply = await client.get("/api/data")
            assert reply.status_code == 200

        rejected = await client.get("/api/data")
        assert rejected.status_code == 429

    async def test_exempt_paths_dont_consume_limit(self, client: AsyncClient) -> None:
        """Exempt traffic leaves the limited route's two requests available."""
        for _ in range(10):
            await client.get("/health")

        for _ in range(2):
            reply = await client.get("/api/data")
            assert reply.status_code == 200


class TestMiddlewareCustomKeyExtractor:
    """Exercise the ``key_extractor`` option, keyed on an X-User-ID header."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Build an app that limits each X-User-ID to 3 requests per window."""
        api = FastAPI()

        def user_id_extractor(request: object) -> str:
            # Fall back to "anonymous" when the object has no dict-like
            # headers or the header is missing.
            headers = getattr(request, "headers", {})
            if not hasattr(headers, "get"):
                return "anonymous"
            return headers.get("X-User-ID", "anonymous")

        api.add_middleware(
            RateLimitMiddleware,
            limit=3,
            window_size=60,
            backend=MemoryBackend(),
            key_extractor=user_id_extractor,
        )

        @api.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return api

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an httpx client wired to the app through ASGITransport."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_different_users_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Exhausting user-1's quota must not block user-2."""
        first_user = {"X-User-ID": "user-1"}
        second_user = {"X-User-ID": "user-2"}

        for _ in range(3):
            reply = await client.get("/api/resource", headers=first_user)
            assert reply.status_code == 200

        reply = await client.get("/api/resource", headers=first_user)
        assert reply.status_code == 429

        reply = await client.get("/api/resource", headers=second_user)
        assert reply.status_code == 200


class TestConvenienceMiddleware:
    """Exercise the SlidingWindow and TokenBucket middleware subclasses."""

    async def test_sliding_window_middleware(self) -> None:
        """SlidingWindowMiddleware allows three requests, then returns 429."""
        api = FastAPI()
        api.add_middleware(
            SlidingWindowMiddleware,
            limit=3,
            window_size=60,
            backend=MemoryBackend(),
        )

        @api.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        async with AsyncClient(
            transport=ASGITransport(app=api), base_url="http://test"
        ) as http:
            for _ in range(3):
                reply = await http.get("/test")
                assert reply.status_code == 200

            reply = await http.get("/test")
            assert reply.status_code == 429

    async def test_token_bucket_middleware(self) -> None:
        """TokenBucketMiddleware allows three requests, then returns 429."""
        api = FastAPI()
        api.add_middleware(
            TokenBucketMiddleware,
            limit=3,
            window_size=60,
            backend=MemoryBackend(),
        )

        @api.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        async with AsyncClient(
            transport=ASGITransport(app=api), base_url="http://test"
        ) as http:
            for _ in range(3):
                reply = await http.get("/test")
                assert reply.status_code == 200

            reply = await http.get("/test")
            assert reply.status_code == 429


class TestMiddlewareErrorHandling:
    """Exercise the middleware with ``skip_on_error=True`` configured."""

    # NOTE(review): nothing in this class makes the backend fail, so only the
    # non-error path of skip_on_error is covered here.

    @pytest.fixture
    def app_skip_on_error(self) -> FastAPI:
        """Build an app whose limiter is created with skip_on_error=True."""
        api = FastAPI()
        api.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=MemoryBackend(),
            skip_on_error=True,
        )

        @api.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return api

    @pytest.fixture
    async def client(
        self, app_skip_on_error: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Yield an httpx client wired to the skip_on_error app."""
        async with AsyncClient(
            transport=ASGITransport(app=app_skip_on_error), base_url="http://test"
        ) as http:
            yield http

    async def test_normal_operation_with_skip_on_error(
        self, client: AsyncClient
    ) -> None:
        """With a working backend, limiting still applies: five pass, sixth is 429."""
        for i in range(5):
            reply = await client.get("/api/resource")
            assert reply.status_code == 200, f"Request {i} should succeed"

        rejected = await client.get("/api/resource")
        assert rejected.status_code == 429


class TestMiddlewareHeaderConfiguration:
    """Exercise the ``include_headers`` and ``error_message`` options."""

    async def test_headers_disabled(self) -> None:
        """With include_headers=False the X-RateLimit-Limit header is absent."""
        api = FastAPI()
        api.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=MemoryBackend(),
            include_headers=False,
        )

        @api.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        async with AsyncClient(
            transport=ASGITransport(app=api), base_url="http://test"
        ) as http:
            reply = await http.get("/test")
            assert reply.status_code == 200
            assert "X-RateLimit-Limit" not in reply.headers

    async def test_custom_error_message(self) -> None:
        """The configured error_message appears as ``detail`` in the 429 body."""
        api = FastAPI()
        api.add_middleware(
            RateLimitMiddleware,
            limit=1,
            window_size=60,
            backend=MemoryBackend(),
            error_message="Custom: Too many requests",
        )

        @api.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        async with AsyncClient(
            transport=ASGITransport(app=api), base_url="http://test"
        ) as http:
            # limit=1: the first call uses the quota, the second is rejected.
            await http.get("/test")
            rejected = await http.get("/test")

            assert rejected.status_code == 429
            payload = rejected.json()
            assert payload["detail"] == "Custom: Too many requests"