Add comprehensive test suite with 134 tests
Covers all algorithms, backends, decorators, middleware, and integration scenarios. Added conftest.py with shared fixtures and pytest-asyncio configuration.
This commit is contained in:
191
tests/conftest.py
Normal file
191
tests/conftest.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Shared fixtures and configuration for tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
SQLiteBackend,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.config import RateLimitConfig
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
from fastapi_traffic.middleware import RateLimitMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create an event loop for the test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def memory_backend() -> AsyncGenerator[MemoryBackend, None]:
|
||||
"""Create a fresh memory backend for each test."""
|
||||
backend = MemoryBackend(max_size=1000, cleanup_interval=60.0)
|
||||
yield backend
|
||||
await backend.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_backend(tmp_path: object) -> AsyncGenerator[SQLiteBackend, None]:
|
||||
"""Create an in-memory SQLite backend for each test."""
|
||||
backend = SQLiteBackend(":memory:", cleanup_interval=60.0)
|
||||
await backend.initialize()
|
||||
yield backend
|
||||
await backend.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def limiter(memory_backend: MemoryBackend) -> AsyncGenerator[RateLimiter, None]:
|
||||
"""Create a rate limiter with memory backend."""
|
||||
limiter = RateLimiter(memory_backend)
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield limiter
|
||||
await limiter.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limit_config() -> RateLimitConfig:
|
||||
"""Create a default rate limit config for testing."""
|
||||
return RateLimitConfig(
|
||||
limit=10,
|
||||
window_size=60.0,
|
||||
algorithm=Algorithm.SLIDING_WINDOW_COUNTER,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(limiter: RateLimiter) -> FastAPI:
|
||||
"""Create a FastAPI app with rate limiting configured."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(
|
||||
request: Request, exc: RateLimitExceeded
|
||||
) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": exc.message,
|
||||
"retry_after": exc.retry_after,
|
||||
},
|
||||
headers=exc.limit_info.to_headers() if exc.limit_info else {},
|
||||
)
|
||||
|
||||
@app.get("/limited")
|
||||
@rate_limit(5, 60)
|
||||
async def limited_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/unlimited")
|
||||
async def unlimited_endpoint() -> dict[str, str]:
|
||||
return {"message": "no limit"}
|
||||
|
||||
def api_key_extractor(request: Request) -> str:
|
||||
return request.headers.get("X-API-Key", "anon")
|
||||
|
||||
@app.get("/custom-key")
|
||||
@rate_limit(5, window_size=60, key_extractor=api_key_extractor)
|
||||
async def custom_key_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/token-bucket")
|
||||
@rate_limit(10, window_size=60, algorithm=Algorithm.TOKEN_BUCKET, burst_size=5)
|
||||
async def token_bucket_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/high-cost")
|
||||
@rate_limit(10, window_size=60, cost=3)
|
||||
async def high_cost_endpoint(request: Request) -> dict[str, str]:
|
||||
return {"message": "success"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async test client."""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_middleware(memory_backend: MemoryBackend) -> FastAPI:
|
||||
"""Create a FastAPI app with rate limit middleware."""
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
limit=10,
|
||||
window_size=60,
|
||||
backend=memory_backend,
|
||||
exempt_paths={"/health"},
|
||||
exempt_ips={"192.168.1.100"},
|
||||
)
|
||||
|
||||
@app.get("/api/resource")
|
||||
async def resource() -> dict[str, str]:
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def middleware_client(
|
||||
app_with_middleware: FastAPI,
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async test client for middleware tests."""
|
||||
transport = ASGITransport(app=app_with_middleware)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
class MockRequest:
|
||||
"""Mock request object for unit tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str = "/test",
|
||||
method: str = "GET",
|
||||
client_host: str = "127.0.0.1",
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.url = type("URL", (), {"path": path})()
|
||||
self.method = method
|
||||
self.client = type("Client", (), {"host": client_host})()
|
||||
self._headers = headers or {}
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
return self._headers
|
||||
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
return self._headers.get(key, default)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request() -> MockRequest:
|
||||
"""Create a mock request for unit tests."""
|
||||
return MockRequest()
|
||||
Reference in New Issue
Block a user