"""Integration tests for fastapi-traffic.

End-to-end tests that verify the complete rate limiting flow across
different configurations and usage patterns.
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient

from fastapi_traffic import (
    Algorithm,
    MemoryBackend,
    RateLimitExceeded,
    RateLimiter,
    rate_limit,
)
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.middleware import RateLimitMiddleware

# Limit applied to the API-key keyed endpoint. Shared by the fixture and the
# test that exhausts it so the two cannot drift apart.
_PREMIUM_LIMIT = 100


def _build_app(backend: MemoryBackend) -> FastAPI:
    """Return a FastAPI app whose lifespan installs a new RateLimiter.

    On startup the lifespan initializes a ``RateLimiter`` built over
    *backend* and registers it with ``set_limiter``. On shutdown it closes
    that limiter. Every fixture in this module passes a new ``MemoryBackend``,
    so each app starts from empty rate-limit state once its lifespan runs.
    """
    limiter = RateLimiter(backend)

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        await limiter.initialize()
        set_limiter(limiter)
        try:
            yield
        finally:
            # Release the limiter even if shutdown is driven by an exception.
            await limiter.close()

    return FastAPI(lifespan=lifespan)


@asynccontextmanager
async def _running_client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` bound to *app* while its lifespan is active.

    ``httpx.ASGITransport`` does not emit ASGI lifespan events. Without
    entering ``app.router.lifespan_context`` explicitly, the limiter created
    in the app's lifespan would never be initialized or registered with
    ``set_limiter``.
    """
    async with app.router.lifespan_context(app):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client


class _MockRequest:
    """Minimal stand-in for a Request, exposing only the attributes set here.

    The object has ``url.path``, ``method``, ``client.host`` and ``headers``.
    It is passed straight to ``RateLimiter`` methods, hence the
    ``type: ignore[arg-type]`` at call sites.
    """

    def __init__(self) -> None:
        self.url = type("URL", (), {"path": "/test"})()
        self.method = "GET"
        self.client = type("Client", (), {"host": "127.0.0.1"})()
        self.headers: dict[str, str] = {}


class TestFullApplicationFlow:
    """Integration tests for a complete application setup."""

    @pytest.fixture
    async def full_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create a fully configured application."""
        app = _build_app(MemoryBackend())

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_handler(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": exc.message,
                    "retry_after": exc.retry_after,
                },
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/api/v1/users")
        @rate_limit(10, 60)
        async def list_users(request: Request) -> dict[str, object]:
            return {"users": [], "count": 0}

        @app.get("/api/v1/users/{user_id}")
        @rate_limit(20, 60)
        async def get_user(request: Request, user_id: int) -> dict[str, object]:
            return {"id": user_id, "name": f"User {user_id}"}

        @app.post("/api/v1/users")
        @rate_limit(5, window_size=60, cost=2)
        async def create_user(request: Request) -> dict[str, object]:
            return {"id": 1, "created": True}

        def get_api_key(request: Request) -> str:
            return request.headers.get("X-API-Key", "anonymous")

        @app.get("/api/v1/premium")
        @rate_limit(_PREMIUM_LIMIT, window_size=60, key_extractor=get_api_key)
        async def premium_endpoint(request: Request) -> dict[str, str]:
            return {"tier": "premium"}

        yield app

    @pytest.fixture
    async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _running_client(full_app) as client:
            yield client

    async def test_different_endpoints_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints maintain separate rate limits."""
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429

        # Exhausting /api/v1/users must not affect /api/v1/users/{id}.
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200

    async def test_cost_based_limiting(self, client: AsyncClient) -> None:
        """Test that cost parameter affects rate limiting."""
        # Limit 5 with cost 2: two requests consume 4, a third would need 6.
        for _ in range(2):
            response = await client.post("/api/v1/users")
            assert response.status_code == 200

        response = await client.post("/api/v1/users")
        assert response.status_code == 429

    async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
        """Test that each API key gets its own independent rate limit."""
        key_a = {"X-API-Key": "key-a"}
        key_b = {"X-API-Key": "key-b"}

        for _ in range(_PREMIUM_LIMIT):
            response = await client.get("/api/v1/premium", headers=key_a)
            assert response.status_code == 200

        response = await client.get("/api/v1/premium", headers=key_a)
        assert response.status_code == 429

        # key-a is exhausted; key-b must still have its full allowance.
        response = await client.get("/api/v1/premium", headers=key_b)
        assert response.status_code == 200

    async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None:
        """Test that a rate-limited endpoint serves requests under its limit."""
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1
        assert data["name"] == "User 1"

    async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
        """Test that 429 responses include Retry-After header."""
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429
        assert "Retry-After" in response.headers

        data = response.json()
        assert data["error"] == "rate_limit_exceeded"
        assert data["retry_after"] is not None


class TestMixedDecoratorAndMiddleware:
    """Test combining decorator and middleware rate limiting."""

    @pytest.fixture
    async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with both middleware and decorator limiting."""
        backend = MemoryBackend()
        app = _build_app(backend)

        app.add_middleware(
            RateLimitMiddleware,
            limit=20,
            window_size=60,
            backend=backend,
            exempt_paths={"/health"},
            key_prefix="global",
        )

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/api/strict")
        @rate_limit(3, 60)
        async def strict_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/api/normal")
        async def normal_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        yield app

    @pytest.fixture
    async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _running_client(mixed_app) as client:
            yield client

    async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
        """Test that health endpoint bypasses middleware limiting."""
        # 30 > middleware limit of 20, so any limiting would surface here.
        for _ in range(30):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_decorator_limit_stricter_than_middleware(
        self, client: AsyncClient
    ) -> None:
        """Test that decorator limit is enforced before middleware limit."""
        for _ in range(3):
            response = await client.get("/api/strict")
            assert response.status_code == 200

        response = await client.get("/api/strict")
        assert response.status_code == 429

    async def test_middleware_limit_applies_to_normal_endpoints(
        self, client: AsyncClient
    ) -> None:
        """Test that middleware limit applies to non-decorated endpoints."""
        for _ in range(20):
            response = await client.get("/api/normal")
            assert response.status_code == 200

        response = await client.get("/api/normal")
        assert response.status_code == 429


class TestConcurrentRequests:
    """Test rate limiting under concurrent load."""

    @pytest.fixture
    async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app for concurrent testing."""
        app = _build_app(MemoryBackend())

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/api/resource")
        @rate_limit(10, 60)
        async def resource(request: Request) -> dict[str, str]:
            # Yield to the event loop so requests genuinely overlap.
            await asyncio.sleep(0.01)
            return {"status": "ok"}

        yield app

    async def test_concurrent_requests_respect_limit(
        self, concurrent_app: FastAPI
    ) -> None:
        """Test that concurrent requests respect rate limit."""
        async with _running_client(concurrent_app) as client:

            async def make_request() -> int:
                response = await client.get("/api/resource")
                return response.status_code

            results = await asyncio.gather(*[make_request() for _ in range(15)])

        success_count = sum(1 for status in results if status == 200)
        rate_limited_count = sum(1 for status in results if status == 429)

        assert success_count == 10
        assert rate_limited_count == 5


class TestLimiterStateManagement:
    """Test RateLimiter state management."""

    async def test_limiter_reset_clears_state(self) -> None:
        """Test that reset clears rate limit state."""
        limiter = RateLimiter(MemoryBackend())
        await limiter.initialize()

        try:
            config = RateLimitConfig(limit=3, window_size=60)
            request = _MockRequest()

            for _ in range(3):
                result = await limiter.check(request, config)  # type: ignore[arg-type]
                assert result.allowed

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert not result.allowed

            await limiter.reset(request, config)  # type: ignore[arg-type]

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert result.allowed
        finally:
            await limiter.close()

    async def test_get_state_returns_current_info(self) -> None:
        """Test that get_state returns current rate limit info."""
        limiter = RateLimiter(MemoryBackend())
        await limiter.initialize()

        try:
            config = RateLimitConfig(limit=5, window_size=60)
            request = _MockRequest()

            await limiter.check(request, config)  # type: ignore[arg-type]
            await limiter.check(request, config)  # type: ignore[arg-type]

            state = await limiter.get_state(request, config)  # type: ignore[arg-type]
            assert state is not None
            assert state.remaining == 3
        finally:
            await limiter.close()


class TestMultipleAlgorithms:
    """Test different algorithms in the same application."""

    @pytest.fixture
    async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with multiple algorithms."""
        app = _build_app(MemoryBackend())

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/sliding-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
        async def sliding_window(request: Request) -> dict[str, str]:
            return {"algorithm": "sliding_window"}

        @app.get("/fixed-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
        async def fixed_window(request: Request) -> dict[str, str]:
            return {"algorithm": "fixed_window"}

        @app.get("/token-bucket")
        @rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
        async def token_bucket(request: Request) -> dict[str, str]:
            return {"algorithm": "token_bucket"}

        yield app

    @pytest.fixture
    async def client(
        self, multi_algo_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running."""
        async with _running_client(multi_algo_app) as client:
            yield client

    async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
        """Test that all algorithms enforce their limits."""
        endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]

        for endpoint in endpoints:
            for i in range(5):
                response = await client.get(endpoint)
                assert response.status_code == 200, f"{endpoint} request {i} failed"

            response = await client.get(endpoint)
            assert response.status_code == 429, f"{endpoint} should be rate limited"