"""Integration tests for fastapi-traffic.

End-to-end tests that verify the complete rate limiting flow across
different configurations and usage patterns.
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient

from fastapi_traffic import (
    Algorithm,
    MemoryBackend,
    RateLimiter,
    RateLimitExceeded,
    rate_limit,
)
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.middleware import RateLimitMiddleware

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, AsyncIterator

# Per-key limit of the /api/v1/premium endpoint. Shared by the fixture and by
# the test that exhausts it so the two cannot drift apart.
_PREMIUM_LIMIT = 100


def _create_app(limiter: RateLimiter) -> FastAPI:
    """Build a FastAPI app whose lifespan manages *limiter*.

    On startup the limiter is initialized and installed via ``set_limiter``.
    On shutdown it is closed. The ``try/finally`` makes sure ``close()`` still
    runs when code executing inside the lifespan raises.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        await limiter.initialize()
        set_limiter(limiter)
        try:
            yield
        finally:
            await limiter.close()

    return FastAPI(lifespan=lifespan)


async def _detail_rate_limit_handler(
    request: Request, exc: RateLimitExceeded
) -> JSONResponse:
    """Translate RateLimitExceeded into a 429 carrying the message as ``detail``."""
    return JSONResponse(status_code=429, content={"detail": exc.message})


@asynccontextmanager
async def _lifespan_client(app: FastAPI) -> AsyncIterator[AsyncClient]:
    """Yield an AsyncClient for *app* while the app's lifespan is active.

    httpx's ``ASGITransport`` only forwards HTTP requests and never sends ASGI
    lifespan events, so the lifespan is entered explicitly here. Without this,
    the limiter wired up in ``_create_app`` would never be initialized,
    installed with ``set_limiter``, or closed.
    """
    transport = ASGITransport(app=app)
    async with app.router.lifespan_context(app):
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client


@asynccontextmanager
async def _running_limiter() -> AsyncIterator[RateLimiter]:
    """Yield an initialized in-memory RateLimiter and always close it afterwards."""
    limiter = RateLimiter(MemoryBackend())
    await limiter.initialize()
    try:
        yield limiter
    finally:
        await limiter.close()


class _MockRequest:
    """Lightweight stand-in for a Starlette ``Request``.

    Exposes only ``url.path``, ``method``, ``client.host`` and ``headers``.
    Callers pass it to the limiter with ``type: ignore`` because it is not a
    real Request.
    """

    def __init__(self, path: str = "/test", host: str = "127.0.0.1") -> None:
        self.url = type("URL", (), {"path": path})()
        self.method = "GET"
        self.client = type("Client", (), {"host": host})()
        self.headers: dict[str, str] = {}


class TestFullApplicationFlow:
    """Integration tests for a complete application setup."""

    @pytest.fixture
    async def full_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create a fully configured application."""
        app = _create_app(RateLimiter(MemoryBackend()))

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_handler(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": exc.message,
                    "retry_after": exc.retry_after,
                },
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/api/v1/users")
        @rate_limit(10, 60)
        async def list_users(request: Request) -> dict[str, object]:
            return {"users": [], "count": 0}

        @app.get("/api/v1/users/{user_id}")
        @rate_limit(20, 60)
        async def get_user(request: Request, user_id: int) -> dict[str, object]:
            return {"id": user_id, "name": f"User {user_id}"}

        @app.post("/api/v1/users")
        @rate_limit(5, window_size=60, cost=2)
        async def create_user(request: Request) -> dict[str, object]:
            return {"id": 1, "created": True}

        def get_api_key(request: Request) -> str:
            return request.headers.get("X-API-Key", "anonymous")

        @app.get("/api/v1/premium")
        @rate_limit(_PREMIUM_LIMIT, window_size=60, key_extractor=get_api_key)
        async def premium_endpoint(request: Request) -> dict[str, str]:
            return {"tier": "premium"}

        yield app

    @pytest.fixture
    async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _lifespan_client(full_app) as client:
            yield client

    async def test_different_endpoints_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints maintain separate rate limits."""
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429

        # Exhausting /api/v1/users must not affect /api/v1/users/{id}.
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200

    async def test_cost_based_limiting(self, client: AsyncClient) -> None:
        """Test that cost parameter affects rate limiting."""
        # Each POST costs 2 of the 5 allowed units: two fit, a third does not.
        for _ in range(2):
            response = await client.post("/api/v1/users")
            assert response.status_code == 200

        response = await client.post("/api/v1/users")
        assert response.status_code == 429

    async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
        """Test that each API key gets its own, independent budget."""
        key_a = {"X-API-Key": "key-a"}
        key_b = {"X-API-Key": "key-b"}

        for _ in range(_PREMIUM_LIMIT):
            response = await client.get("/api/v1/premium", headers=key_a)
            assert response.status_code == 200

        # key-a is now exhausted...
        response = await client.get("/api/v1/premium", headers=key_a)
        assert response.status_code == 429

        # ...but key-b was not charged for any of key-a's requests.
        for _ in range(5):
            response = await client.get("/api/v1/premium", headers=key_b)
            assert response.status_code == 200

    async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None:
        """Test that basic rate limiting is functional."""
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1

    async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
        """Test that 429 responses include Retry-After header."""
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429
        assert "Retry-After" in response.headers

        data = response.json()
        assert data["error"] == "rate_limit_exceeded"
        assert data["retry_after"] is not None


class TestMixedDecoratorAndMiddleware:
    """Test combining decorator and middleware rate limiting."""

    @pytest.fixture
    async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with both middleware and decorator limiting."""
        backend = MemoryBackend()
        app = _create_app(RateLimiter(backend))

        # The middleware shares the limiter's backend but uses its own prefix.
        app.add_middleware(
            RateLimitMiddleware,
            limit=20,
            window_size=60,
            backend=backend,
            exempt_paths={"/health"},
            key_prefix="global",
        )
        app.exception_handler(RateLimitExceeded)(_detail_rate_limit_handler)

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/api/strict")
        @rate_limit(3, 60)
        async def strict_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/api/normal")
        async def normal_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        yield app

    @pytest.fixture
    async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _lifespan_client(mixed_app) as client:
            yield client

    async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
        """Test that health endpoint bypasses middleware limiting."""
        for _ in range(30):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_decorator_limit_stricter_than_middleware(
        self, client: AsyncClient
    ) -> None:
        """Test that decorator limit is enforced before middleware limit."""
        for _ in range(3):
            response = await client.get("/api/strict")
            assert response.status_code == 200

        response = await client.get("/api/strict")
        assert response.status_code == 429

    async def test_middleware_limit_applies_to_normal_endpoints(
        self, client: AsyncClient
    ) -> None:
        """Test that middleware limit applies to non-decorated endpoints."""
        for _ in range(20):
            response = await client.get("/api/normal")
            assert response.status_code == 200

        response = await client.get("/api/normal")
        assert response.status_code == 429


class TestConcurrentRequests:
    """Test rate limiting under concurrent load."""

    @pytest.fixture
    async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app for concurrent testing."""
        app = _create_app(RateLimiter(MemoryBackend()))
        app.exception_handler(RateLimitExceeded)(_detail_rate_limit_handler)

        @app.get("/api/resource")
        @rate_limit(10, 60)
        async def resource(request: Request) -> dict[str, str]:
            # Small delay so the concurrent requests actually overlap.
            await asyncio.sleep(0.01)
            return {"status": "ok"}

        yield app

    async def test_concurrent_requests_respect_limit(
        self, concurrent_app: FastAPI
    ) -> None:
        """Test that concurrent requests respect rate limit."""
        async with _lifespan_client(concurrent_app) as client:

            async def make_request() -> int:
                response = await client.get("/api/resource")
                return response.status_code

            results = await asyncio.gather(*(make_request() for _ in range(15)))

        assert results.count(200) == 10
        assert results.count(429) == 5


class TestLimiterStateManagement:
    """Test RateLimiter state management."""

    async def test_limiter_reset_clears_state(self) -> None:
        """Test that reset clears rate limit state."""
        config = RateLimitConfig(limit=3, window_size=60)
        request = _MockRequest()

        async with _running_limiter() as limiter:
            for _ in range(3):
                result = await limiter.check(request, config)  # type: ignore[arg-type]
                assert result.allowed

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert not result.allowed

            await limiter.reset(request, config)  # type: ignore[arg-type]

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert result.allowed

    async def test_get_state_returns_current_info(self) -> None:
        """Test that get_state returns current rate limit info."""
        config = RateLimitConfig(limit=5, window_size=60)
        request = _MockRequest()

        async with _running_limiter() as limiter:
            await limiter.check(request, config)  # type: ignore[arg-type]
            await limiter.check(request, config)  # type: ignore[arg-type]

            state = await limiter.get_state(request, config)  # type: ignore[arg-type]
            assert state is not None
            assert state.remaining == 3


class TestMultipleAlgorithms:
    """Test different algorithms in the same application."""

    @pytest.fixture
    async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with multiple algorithms."""
        app = _create_app(RateLimiter(MemoryBackend()))
        app.exception_handler(RateLimitExceeded)(_detail_rate_limit_handler)

        @app.get("/sliding-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
        async def sliding_window(request: Request) -> dict[str, str]:
            return {"algorithm": "sliding_window"}

        @app.get("/fixed-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
        async def fixed_window(request: Request) -> dict[str, str]:
            return {"algorithm": "fixed_window"}

        @app.get("/token-bucket")
        @rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
        async def token_bucket(request: Request) -> dict[str, str]:
            return {"algorithm": "token_bucket"}

        yield app

    @pytest.fixture
    async def client(
        self, multi_algo_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _lifespan_client(multi_algo_app) as client:
            yield client

    async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
        """Test that all algorithms enforce their limits."""
        endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]

        for endpoint in endpoints:
            for i in range(5):
                response = await client.get(endpoint)
                assert response.status_code == 200, f"{endpoint} request {i} failed"

            response = await client.get(endpoint)
            assert response.status_code == 429, f"{endpoint} should be rate limited"