Add comprehensive test suite with 134 tests
Covers all algorithms, backends, decorators, middleware, and integration scenarios. Added conftest.py with shared fixtures and pytest-asyncio configuration.
This commit is contained in:
407
tests/test_integration.py
Normal file
407
tests/test_integration.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""Integration tests for fastapi-traffic.
|
||||
|
||||
End-to-end tests that verify the complete rate limiting flow
|
||||
across different configurations and usage patterns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.config import RateLimitConfig
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
from fastapi_traffic.middleware import RateLimitMiddleware
|
||||
|
||||
|
||||
class TestFullApplicationFlow:
|
||||
"""Integration tests for a complete application setup."""
|
||||
|
||||
@pytest.fixture
|
||||
async def full_app(self) -> AsyncGenerator[FastAPI, None]:
|
||||
"""Create a fully configured application."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(
|
||||
request: Request, exc: RateLimitExceeded
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": exc.message,
|
||||
"retry_after": exc.retry_after,
|
||||
},
|
||||
headers=exc.limit_info.to_headers() if exc.limit_info else {},
|
||||
)
|
||||
|
||||
@app.get("/api/v1/users")
|
||||
@rate_limit(10, 60)
|
||||
async def list_users(request: Request) -> dict[str, object]:
|
||||
return {"users": [], "count": 0}
|
||||
|
||||
@app.get("/api/v1/users/{user_id}")
|
||||
@rate_limit(20, 60)
|
||||
async def get_user(request: Request, user_id: int) -> dict[str, object]:
|
||||
return {"id": user_id, "name": f"User {user_id}"}
|
||||
|
||||
@app.post("/api/v1/users")
|
||||
@rate_limit(5, window_size=60, cost=2)
|
||||
async def create_user(request: Request) -> dict[str, object]:
|
||||
return {"id": 1, "created": True}
|
||||
|
||||
def get_api_key(request: Request) -> str:
|
||||
return request.headers.get("X-API-Key", "anonymous")
|
||||
|
||||
@app.get("/api/v1/premium")
|
||||
@rate_limit(100, window_size=60, key_extractor=get_api_key)
|
||||
async def premium_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"tier": "premium"}
|
||||
|
||||
yield app
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test client with lifespan."""
|
||||
transport = ASGITransport(app=full_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async def test_different_endpoints_have_separate_limits(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that different endpoints maintain separate rate limits."""
|
||||
for _ in range(10):
|
||||
response = await client.get("/api/v1/users")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get("/api/v1/users")
|
||||
assert response.status_code == 429
|
||||
|
||||
response = await client.get("/api/v1/users/1")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_cost_based_limiting(self, client: AsyncClient) -> None:
|
||||
"""Test that cost parameter affects rate limiting."""
|
||||
for _ in range(2):
|
||||
response = await client.post("/api/v1/users")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.post("/api/v1/users")
|
||||
assert response.status_code == 429
|
||||
|
||||
async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
|
||||
"""Test rate limiting by API key."""
|
||||
for _ in range(5):
|
||||
response = await client.get(
|
||||
"/api/v1/premium", headers={"X-API-Key": "key-a"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
for _ in range(5):
|
||||
response = await client.get(
|
||||
"/api/v1/premium", headers={"X-API-Key": "key-b"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_basic_rate_limiting_works(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that basic rate limiting is functional."""
|
||||
# Make a request and verify it works
|
||||
response = await client.get("/api/v1/users/1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == 1
|
||||
|
||||
async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
|
||||
"""Test that 429 responses include Retry-After header."""
|
||||
for _ in range(10):
|
||||
await client.get("/api/v1/users")
|
||||
|
||||
response = await client.get("/api/v1/users")
|
||||
assert response.status_code == 429
|
||||
assert "Retry-After" in response.headers
|
||||
data = response.json()
|
||||
assert data["error"] == "rate_limit_exceeded"
|
||||
assert data["retry_after"] is not None
|
||||
|
||||
|
||||
class TestMixedDecoratorAndMiddleware:
|
||||
"""Test combining decorator and middleware rate limiting."""
|
||||
|
||||
@pytest.fixture
|
||||
async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
|
||||
"""Create app with both middleware and decorator limiting."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
limit=20,
|
||||
window_size=60,
|
||||
backend=backend,
|
||||
exempt_paths={"/health"},
|
||||
key_prefix="global",
|
||||
)
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
return JSONResponse(status_code=429, content={"detail": exc.message})
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.get("/api/strict")
|
||||
@rate_limit(3, 60)
|
||||
async def strict_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/normal")
|
||||
async def normal_endpoint() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
yield app
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test client."""
|
||||
transport = ASGITransport(app=mixed_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
|
||||
"""Test that health endpoint bypasses middleware limiting."""
|
||||
for _ in range(30):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_decorator_limit_stricter_than_middleware(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that decorator limit is enforced before middleware limit."""
|
||||
for _ in range(3):
|
||||
response = await client.get("/api/strict")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get("/api/strict")
|
||||
assert response.status_code == 429
|
||||
|
||||
async def test_middleware_limit_applies_to_normal_endpoints(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that middleware limit applies to non-decorated endpoints."""
|
||||
for _ in range(20):
|
||||
response = await client.get("/api/normal")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get("/api/normal")
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
class TestConcurrentRequests:
|
||||
"""Test rate limiting under concurrent load."""
|
||||
|
||||
@pytest.fixture
|
||||
async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
|
||||
"""Create app for concurrent testing."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
return JSONResponse(status_code=429, content={"detail": exc.message})
|
||||
|
||||
@app.get("/api/resource")
|
||||
@rate_limit(10, 60)
|
||||
async def resource(request: Request) -> dict[str, str]:
|
||||
await asyncio.sleep(0.01)
|
||||
return {"status": "ok"}
|
||||
|
||||
yield app
|
||||
|
||||
async def test_concurrent_requests_respect_limit(
|
||||
self, concurrent_app: FastAPI
|
||||
) -> None:
|
||||
"""Test that concurrent requests respect rate limit."""
|
||||
transport = ASGITransport(app=concurrent_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
|
||||
async def make_request() -> int:
|
||||
response = await client.get("/api/resource")
|
||||
return response.status_code
|
||||
|
||||
results = await asyncio.gather(*[make_request() for _ in range(15)])
|
||||
|
||||
success_count = sum(1 for r in results if r == 200)
|
||||
rate_limited_count = sum(1 for r in results if r == 429)
|
||||
|
||||
assert success_count == 10
|
||||
assert rate_limited_count == 5
|
||||
|
||||
|
||||
class TestLimiterStateManagement:
|
||||
"""Test RateLimiter state management."""
|
||||
|
||||
async def test_limiter_reset_clears_state(self) -> None:
|
||||
"""Test that reset clears rate limit state."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
await limiter.initialize()
|
||||
|
||||
try:
|
||||
config = RateLimitConfig(limit=3, window_size=60)
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self) -> None:
|
||||
self.url = type("URL", (), {"path": "/test"})()
|
||||
self.method = "GET"
|
||||
self.client = type("Client", (), {"host": "127.0.0.1"})()
|
||||
self.headers: dict[str, str] = {}
|
||||
|
||||
request = MockRequest()
|
||||
|
||||
for _ in range(3):
|
||||
result = await limiter.check(request, config) # type: ignore[arg-type]
|
||||
assert result.allowed
|
||||
|
||||
result = await limiter.check(request, config) # type: ignore[arg-type]
|
||||
assert not result.allowed
|
||||
|
||||
await limiter.reset(request, config) # type: ignore[arg-type]
|
||||
|
||||
result = await limiter.check(request, config) # type: ignore[arg-type]
|
||||
assert result.allowed
|
||||
finally:
|
||||
await limiter.close()
|
||||
|
||||
async def test_get_state_returns_current_info(self) -> None:
|
||||
"""Test that get_state returns current rate limit info."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
await limiter.initialize()
|
||||
|
||||
try:
|
||||
config = RateLimitConfig(limit=5, window_size=60)
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self) -> None:
|
||||
self.url = type("URL", (), {"path": "/test"})()
|
||||
self.method = "GET"
|
||||
self.client = type("Client", (), {"host": "127.0.0.1"})()
|
||||
self.headers: dict[str, str] = {}
|
||||
|
||||
request = MockRequest()
|
||||
|
||||
await limiter.check(request, config) # type: ignore[arg-type]
|
||||
await limiter.check(request, config) # type: ignore[arg-type]
|
||||
|
||||
state = await limiter.get_state(request, config) # type: ignore[arg-type]
|
||||
assert state is not None
|
||||
assert state.remaining == 3
|
||||
finally:
|
||||
await limiter.close()
|
||||
|
||||
|
||||
class TestMultipleAlgorithms:
|
||||
"""Test different algorithms in the same application."""
|
||||
|
||||
@pytest.fixture
|
||||
async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
|
||||
"""Create app with multiple algorithms."""
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
return JSONResponse(status_code=429, content={"detail": exc.message})
|
||||
|
||||
@app.get("/sliding-window")
|
||||
@rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
|
||||
async def sliding_window(request: Request) -> dict[str, str]:
|
||||
return {"algorithm": "sliding_window"}
|
||||
|
||||
@app.get("/fixed-window")
|
||||
@rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
|
||||
async def fixed_window(request: Request) -> dict[str, str]:
|
||||
return {"algorithm": "fixed_window"}
|
||||
|
||||
@app.get("/token-bucket")
|
||||
@rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
|
||||
async def token_bucket(request: Request) -> dict[str, str]:
|
||||
return {"algorithm": "token_bucket"}
|
||||
|
||||
yield app
|
||||
|
||||
@pytest.fixture
|
||||
async def client(
|
||||
self, multi_algo_app: FastAPI
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test client."""
|
||||
transport = ASGITransport(app=multi_algo_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
|
||||
"""Test that all algorithms enforce their limits."""
|
||||
endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]
|
||||
|
||||
for endpoint in endpoints:
|
||||
for i in range(5):
|
||||
response = await client.get(endpoint)
|
||||
assert response.status_code == 200, f"{endpoint} request {i} failed"
|
||||
|
||||
response = await client.get(endpoint)
|
||||
assert response.status_code == 429, f"{endpoint} should be rate limited"
|
||||
Reference in New Issue
Block a user