"""Shared fixtures and configuration for tests.""" from __future__ import annotations import asyncio from typing import TYPE_CHECKING import pytest from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from httpx import ASGITransport, AsyncClient from fastapi_traffic import ( Algorithm, MemoryBackend, RateLimiter, RateLimitExceeded, SQLiteBackend, rate_limit, ) from fastapi_traffic.core.config import RateLimitConfig from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.middleware import RateLimitMiddleware if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator pass @pytest.fixture(scope="session") def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: """Create an event loop for the test session.""" loop = asyncio.new_event_loop() yield loop loop.close() @pytest.fixture async def memory_backend() -> AsyncGenerator[MemoryBackend, None]: """Create a fresh memory backend for each test.""" backend = MemoryBackend(max_size=1000, cleanup_interval=60.0) yield backend await backend.close() @pytest.fixture async def sqlite_backend(tmp_path: object) -> AsyncGenerator[SQLiteBackend, None]: """Create an in-memory SQLite backend for each test.""" backend = SQLiteBackend(":memory:", cleanup_interval=60.0) await backend.initialize() yield backend await backend.close() @pytest.fixture async def limiter(memory_backend: MemoryBackend) -> AsyncGenerator[RateLimiter, None]: """Create a rate limiter with memory backend.""" limiter = RateLimiter(memory_backend) await limiter.initialize() set_limiter(limiter) yield limiter await limiter.close() @pytest.fixture def rate_limit_config() -> RateLimitConfig: """Create a default rate limit config for testing.""" return RateLimitConfig( limit=10, window_size=60.0, algorithm=Algorithm.SLIDING_WINDOW_COUNTER, ) @pytest.fixture def app(limiter: RateLimiter) -> FastAPI: """Create a FastAPI app with rate limiting configured.""" app = FastAPI() @app.exception_handler(RateLimitExceeded) async def rate_limit_handler( request: Request, exc: RateLimitExceeded ) -> JSONResponse: return JSONResponse( status_code=429, content={ "detail": exc.message, "retry_after": exc.retry_after, }, headers=exc.limit_info.to_headers() if exc.limit_info else {}, ) @app.get("/limited") @rate_limit(5, 60) async def limited_endpoint(request: Request) -> dict[str, str]: return {"message": "success"} @app.get("/unlimited") async def unlimited_endpoint() -> dict[str, str]: return {"message": "no limit"} def api_key_extractor(request: Request) -> str: return request.headers.get("X-API-Key", "anon") @app.get("/custom-key") @rate_limit(5, window_size=60, key_extractor=api_key_extractor) async def custom_key_endpoint(request: Request) -> dict[str, str]: return {"message": "success"} @app.get("/token-bucket") @rate_limit(10, window_size=60, algorithm=Algorithm.TOKEN_BUCKET, burst_size=5) async def token_bucket_endpoint(request: Request) -> dict[str, str]: return {"message": "success"} @app.get("/high-cost") @rate_limit(10, window_size=60, cost=3) async def high_cost_endpoint(request: Request) -> dict[str, str]: return {"message": "success"} return app @pytest.fixture async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]: """Create an async test client.""" transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client @pytest.fixture def app_with_middleware(memory_backend: MemoryBackend) -> FastAPI: """Create a FastAPI app with rate limit middleware.""" app = FastAPI() app.add_middleware( RateLimitMiddleware, limit=10, window_size=60, backend=memory_backend, exempt_paths={"/health"}, exempt_ips={"192.168.1.100"}, ) @app.get("/api/resource") async def resource() -> dict[str, str]: return {"message": "success"} @app.get("/health") async def health() -> dict[str, str]: return {"status": "ok"} return app @pytest.fixture async def middleware_client( app_with_middleware: FastAPI, ) -> AsyncGenerator[AsyncClient, None]: """Create an async test client for middleware tests.""" transport = ASGITransport(app=app_with_middleware) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client class MockRequest: """Mock request object for unit tests.""" def __init__( self, path: str = "/test", method: str = "GET", client_host: str = "127.0.0.1", headers: dict[str, str] | None = None, ) -> None: self.url = type("URL", (), {"path": path})() self.method = method self.client = type("Client", (), {"host": client_host})() self._headers = headers or {} @property def headers(self) -> dict[str, str]: return self._headers def get(self, key: str, default: str | None = None) -> str | None: return self._headers.get(key, default) @pytest.fixture def mock_request() -> MockRequest: """Create a mock request for unit tests.""" return MockRequest()