Files
fastapi-traffic/tests/test_integration.py
zanewalker dfaa0aaec4 Add comprehensive test suite with 134 tests
Covers all algorithms, backends, decorators, middleware, and integration
scenarios. Added conftest.py with shared fixtures and pytest-asyncio
configuration.
2026-01-09 00:50:25 +00:00

408 lines
14 KiB
Python

"""Integration tests for fastapi-traffic.
End-to-end tests that verify the complete rate limiting flow
across different configurations and usage patterns.
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.middleware import RateLimitMiddleware
class TestFullApplicationFlow:
    """Integration tests for a complete application setup."""

    @pytest.fixture
    async def full_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create a fully configured application and run its lifespan.

        The app exposes four rate-limited routes (list, get and create users,
        plus a premium route keyed by the ``X-API-Key`` header) and a handler
        that turns ``RateLimitExceeded`` into a JSON 429 response.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan is entered explicitly around the ``yield``. Without that,
        the limiter built here would never be initialized or registered
        through ``set_limiter``.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Close the limiter even if shutdown is reached via an error.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_handler(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": exc.message,
                    "retry_after": exc.retry_after,
                },
                # Rate limit headers are only available when limit_info is set.
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/api/v1/users")
        @rate_limit(10, 60)
        async def list_users(request: Request) -> dict[str, object]:
            return {"users": [], "count": 0}

        @app.get("/api/v1/users/{user_id}")
        @rate_limit(20, 60)
        async def get_user(request: Request, user_id: int) -> dict[str, object]:
            return {"id": user_id, "name": f"User {user_id}"}

        @app.post("/api/v1/users")
        @rate_limit(5, window_size=60, cost=2)
        async def create_user(request: Request) -> dict[str, object]:
            return {"id": 1, "created": True}

        def get_api_key(request: Request) -> str:
            # Requests without the header share the "anonymous" key.
            return request.headers.get("X-API-Key", "anonymous")

        @app.get("/api/v1/premium")
        @rate_limit(100, window_size=60, key_extractor=get_api_key)
        async def premium_endpoint(request: Request) -> dict[str, str]:
            return {"tier": "premium"}

        # Run startup now and shutdown at fixture teardown.
        async with app.router.lifespan_context(app):
            yield app

    @pytest.fixture
    async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create an in-process test client for ``full_app``.

        The application's lifespan is managed by the ``full_app`` fixture.
        """
        transport = ASGITransport(app=full_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_different_endpoints_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints maintain separate rate limits."""
        # Use up the list endpoint's limit of 10.
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429

        # The detail endpoint must still accept requests.
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200

    async def test_cost_based_limiting(self, client: AsyncClient) -> None:
        """Test that cost parameter affects rate limiting."""
        # The route is declared with limit 5 and cost 2, so the test expects
        # two requests to pass and the third to be rejected.
        for _ in range(2):
            response = await client.post("/api/v1/users")
            assert response.status_code == 200

        response = await client.post("/api/v1/users")
        assert response.status_code == 429

    async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
        """Test rate limiting by API key."""
        # Five requests per key, well under the route's limit of 100; this
        # only checks that requests with either key are accepted.
        for _ in range(5):
            response = await client.get(
                "/api/v1/premium", headers={"X-API-Key": "key-a"}
            )
            assert response.status_code == 200

        for _ in range(5):
            response = await client.get(
                "/api/v1/premium", headers={"X-API-Key": "key-b"}
            )
            assert response.status_code == 200

    async def test_basic_rate_limiting_works(
        self, client: AsyncClient
    ) -> None:
        """Test that basic rate limiting is functional."""
        # A single request to a rate-limited route succeeds and returns data.
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1

    async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
        """Test that 429 responses include Retry-After header."""
        # Consume the limit; these responses are intentionally not asserted.
        for _ in range(10):
            await client.get("/api/v1/users")

        response = await client.get("/api/v1/users")
        assert response.status_code == 429
        assert "Retry-After" in response.headers
        data = response.json()
        assert data["error"] == "rate_limit_exceeded"
        assert data["retry_after"] is not None
class TestMixedDecoratorAndMiddleware:
    """Test combining decorator and middleware rate limiting."""

    @pytest.fixture
    async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with both middleware and decorator limiting.

        ``RateLimitMiddleware`` is added with a limit of 20 per 60-second
        window, ``/health`` exempted and the same backend instance that
        backs the decorator limiter. ``/api/strict`` also carries a
        decorator limit of 3.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan is entered explicitly around the ``yield``. Without that,
        the limiter built here would never be initialized or registered
        through ``set_limiter``.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Close the limiter even if shutdown is reached via an error.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)
        app.add_middleware(
            RateLimitMiddleware,
            limit=20,
            window_size=60,
            backend=backend,
            exempt_paths={"/health"},
            key_prefix="global",
        )

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/api/strict")
        @rate_limit(3, 60)
        async def strict_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/api/normal")
        async def normal_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        # Run startup now and shutdown at fixture teardown.
        async with app.router.lifespan_context(app):
            yield app

    @pytest.fixture
    async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create an in-process test client for ``mixed_app``."""
        transport = ASGITransport(app=mixed_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
        """Test that health endpoint bypasses middleware limiting."""
        # 30 requests exceed the middleware limit of 20; all must succeed
        # because /health is listed in exempt_paths.
        for _ in range(30):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_decorator_limit_stricter_than_middleware(
        self, client: AsyncClient
    ) -> None:
        """Test that decorator limit is enforced before middleware limit."""
        # The decorator limit of 3 is hit while the middleware limit of 20
        # is still far from exhausted.
        for _ in range(3):
            response = await client.get("/api/strict")
            assert response.status_code == 200

        response = await client.get("/api/strict")
        assert response.status_code == 429

    async def test_middleware_limit_applies_to_normal_endpoints(
        self, client: AsyncClient
    ) -> None:
        """Test that middleware limit applies to non-decorated endpoints."""
        for _ in range(20):
            response = await client.get("/api/normal")
            assert response.status_code == 200

        # Request 21 exceeds the middleware limit.
        response = await client.get("/api/normal")
        assert response.status_code == 429
class TestConcurrentRequests:
    """Test rate limiting under concurrent load."""

    @pytest.fixture
    async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app for concurrent testing.

        The single route ``/api/resource`` is limited to 10 requests per
        60-second window and awaits briefly so that concurrently issued
        requests overlap inside the handler.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan is entered explicitly around the ``yield``. Without that,
        the limiter built here would never be initialized or registered
        through ``set_limiter``.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Close the limiter even if shutdown is reached via an error.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/api/resource")
        @rate_limit(10, 60)
        async def resource(request: Request) -> dict[str, str]:
            await asyncio.sleep(0.01)
            return {"status": "ok"}

        # Run startup now and shutdown at fixture teardown.
        async with app.router.lifespan_context(app):
            yield app

    async def test_concurrent_requests_respect_limit(
        self, concurrent_app: FastAPI
    ) -> None:
        """Test that concurrent requests respect rate limit."""
        transport = ASGITransport(app=concurrent_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:

            async def make_request() -> int:
                response = await client.get("/api/resource")
                return response.status_code

            # Fire 15 requests at once against a limit of 10.
            results = await asyncio.gather(*[make_request() for _ in range(15)])

        success_count = sum(1 for r in results if r == 200)
        rate_limited_count = sum(1 for r in results if r == 429)

        # Exactly the limit is admitted; the rest are rejected.
        assert success_count == 10
        assert rate_limited_count == 5
class TestLimiterStateManagement:
    """Exercise RateLimiter state handling directly, without an ASGI app."""

    class _StubRequest:
        """Minimal stand-in carrying the request attributes set below."""

        def __init__(self) -> None:
            self.url = type("URL", (), {"path": "/test"})()
            self.method = "GET"
            self.client = type("Client", (), {"host": "127.0.0.1"})()
            self.headers: dict[str, str] = {}

    @staticmethod
    @asynccontextmanager
    async def _started_limiter() -> AsyncGenerator[RateLimiter, None]:
        """Yield an initialized memory-backed limiter and close it on exit."""
        rate_limiter = RateLimiter(MemoryBackend())
        await rate_limiter.initialize()
        try:
            yield rate_limiter
        finally:
            await rate_limiter.close()

    async def test_limiter_reset_clears_state(self) -> None:
        """Verify that reset() makes an exhausted key usable again."""
        async with self._started_limiter() as limiter:
            config = RateLimitConfig(limit=3, window_size=60)
            request = self._StubRequest()

            # The first three checks fit within the limit.
            for _ in range(3):
                outcome = await limiter.check(request, config)  # type: ignore[arg-type]
                assert outcome.allowed

            # The fourth check is over the limit.
            outcome = await limiter.check(request, config)  # type: ignore[arg-type]
            assert not outcome.allowed

            await limiter.reset(request, config)  # type: ignore[arg-type]

            # After the reset the same request is allowed again.
            outcome = await limiter.check(request, config)  # type: ignore[arg-type]
            assert outcome.allowed

    async def test_get_state_returns_current_info(self) -> None:
        """Verify that get_state() reports the remaining allowance."""
        async with self._started_limiter() as limiter:
            config = RateLimitConfig(limit=5, window_size=60)
            request = self._StubRequest()

            # Two checks against a limit of five.
            await limiter.check(request, config)  # type: ignore[arg-type]
            await limiter.check(request, config)  # type: ignore[arg-type]

            state = await limiter.get_state(request, config)  # type: ignore[arg-type]
            assert state is not None
            assert state.remaining == 3
class TestMultipleAlgorithms:
    """Test different algorithms in the same application."""

    @pytest.fixture
    async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with multiple algorithms.

        Three routes are each limited to 5 requests per 60-second window,
        using the sliding-window, fixed-window and token-bucket algorithms
        respectively.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan is entered explicitly around the ``yield``. Without that,
        the limiter built here would never be initialized or registered
        through ``set_limiter``.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Close the limiter even if shutdown is reached via an error.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/sliding-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
        async def sliding_window(request: Request) -> dict[str, str]:
            return {"algorithm": "sliding_window"}

        @app.get("/fixed-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
        async def fixed_window(request: Request) -> dict[str, str]:
            return {"algorithm": "fixed_window"}

        @app.get("/token-bucket")
        @rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
        async def token_bucket(request: Request) -> dict[str, str]:
            return {"algorithm": "token_bucket"}

        # Run startup now and shutdown at fixture teardown.
        async with app.router.lifespan_context(app):
            yield app

    @pytest.fixture
    async def client(
        self, multi_algo_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create an in-process test client for ``multi_algo_app``."""
        transport = ASGITransport(app=multi_algo_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
        """Test that all algorithms enforce their limits."""
        endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]

        for endpoint in endpoints:
            # Five requests fit within each route's limit.
            for i in range(5):
                response = await client.get(endpoint)
                assert response.status_code == 200, f"{endpoint} request {i} failed"

            # The sixth request must be rejected.
            response = await client.get(endpoint)
            assert response.status_code == 429, f"{endpoint} should be rate limited"