"""Basic usage examples for fastapi-traffic.""" from __future__ import annotations from contextlib import asynccontextmanager from typing import TYPE_CHECKING from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse from fastapi_traffic import ( Algorithm, RateLimiter, RateLimitExceeded, SQLiteBackend, rate_limit, ) from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.limiter import set_limiter if TYPE_CHECKING: from collections.abc import AsyncIterator # Configure global rate limiter with SQLite backend for persistence backend = SQLiteBackend("rate_limits.db") limiter = RateLimiter(backend) set_limiter(limiter) @asynccontextmanager async def lifespan(_: FastAPI) -> AsyncIterator[None]: """Manage application lifespan - startup and shutdown.""" # Startup: Initialize the rate limiter await limiter.initialize() yield # Shutdown: Cleanup await limiter.close() app = FastAPI(title="FastAPI Traffic Example", lifespan=lifespan) # Exception handler for rate limit exceeded @app.exception_handler(RateLimitExceeded) async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse: """Handle rate limit exceeded exceptions.""" headers = exc.limit_info.to_headers() if exc.limit_info else {} return JSONResponse( status_code=429, content={ "error": "rate_limit_exceeded", "message": exc.message, "retry_after": exc.retry_after, }, headers=headers, ) # Example 1: Basic decorator usage @app.get("/api/basic") @rate_limit(100, 60) # 100 requests per minute async def basic_endpoint(_: Request) -> dict[str, str]: """Basic rate-limited endpoint.""" return {"message": "Hello, World!"} # Example 2: Custom algorithm @app.get("/api/token-bucket") @rate_limit( limit=50, window_size=60, algorithm=Algorithm.TOKEN_BUCKET, burst_size=10, # Allow bursts of up to 10 requests ) async def token_bucket_endpoint(_: Request) -> dict[str, str]: """Endpoint using token bucket algorithm.""" return {"message": "Token bucket rate limiting"} # Example 3: Sliding window for precise rate limiting @app.get("/api/sliding-window") @rate_limit( limit=30, window_size=60, algorithm=Algorithm.SLIDING_WINDOW, ) async def sliding_window_endpoint(_: Request) -> dict[str, str]: """Endpoint using sliding window algorithm.""" return {"message": "Sliding window rate limiting"} # Example 4: Custom key extractor (rate limit by API key) def api_key_extractor(request: Request) -> str: """Extract API key from header for rate limiting.""" api_key = request.headers.get("X-API-Key", "anonymous") return f"api_key:{api_key}" @app.get("/api/by-api-key") @rate_limit( limit=1000, window_size=3600, # 1000 requests per hour key_extractor=api_key_extractor, ) async def api_key_endpoint(_: Request) -> dict[str, str]: """Endpoint rate limited by API key.""" return {"message": "Rate limited by API key"} # Example 5: Using dependency injection rate_limit_dep = RateLimitDependency(limit=20, window_size=60) @app.get("/api/dependency") async def dependency_endpoint( _: Request, rate_info: dict[str, object] = Depends(rate_limit_dep), ) -> dict[str, object]: """Endpoint using rate limit as dependency.""" return { "message": "Rate limit info available", "rate_limit": rate_info, } # Example 6: Exempt certain requests def is_admin(request: Request) -> bool: """Check if request is from admin.""" return request.headers.get("X-Admin-Token") == "secret-admin-token" @app.get("/api/admin-exempt") @rate_limit( limit=10, window_size=60, exempt_when=is_admin, ) async def admin_exempt_endpoint(_: Request) -> dict[str, str]: """Endpoint with admin exemption.""" return {"message": "Admins are exempt from rate limiting"} # Example 7: Different costs for different operations @app.post("/api/expensive") @rate_limit( limit=100, window_size=60, cost=10, # This endpoint costs 10 tokens per request ) async def expensive_endpoint(_: Request) -> dict[str, str]: """Expensive operation that costs more tokens.""" return {"message": "Expensive operation completed"} # Example 8: Global middleware rate limiting # Uncomment to enable global rate limiting # app.add_middleware( # RateLimitMiddleware, # limit=1000, # window_size=60, # exempt_paths={"/health", "/docs", "/openapi.json"}, # ) @app.get("/health") async def health_check() -> dict[str, str]: """Health check endpoint (typically exempt from rate limiting).""" return {"status": "healthy"} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)