"""Examples demonstrating middleware-based rate limiting.

Every route is rate limited by a single app-wide ``RateLimitMiddleware``,
except for the paths and IPs listed as exempt in its configuration. Run this
module as a script to serve the app with uvicorn (see ``--help`` for options).
"""

from __future__ import annotations

from fastapi import FastAPI, Request

from fastapi_traffic import MemoryBackend
from fastapi_traffic.middleware import RateLimitMiddleware

# Alternative middleware options (uncomment to use):
# from fastapi_traffic.middleware import SlidingWindowMiddleware
# from fastapi_traffic.middleware import TokenBucketMiddleware

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8001

app = FastAPI(title="Middleware Rate Limiting")


# Custom key extractor for middleware
def get_client_identifier(request: Request) -> str:
    """Return the rate-limit key for *request*.

    Requests carrying a non-empty ``X-API-Key`` header are keyed by that
    header value (``"api_key:<key>"``). Otherwise they are keyed by the
    client's host (``"ip:<host>"``). When the request has no client address
    at all, the constant key ``"unknown"`` is returned, so all such requests
    share one bucket.

    The ``api_key:`` / ``ip:`` prefixes keep the two namespaces separate, so
    an API key whose text looks like an IP address cannot collide with a
    real per-IP bucket.

    NOTE(security): the header value is not validated here. A client can send
    a different made-up key on every request and land in a fresh bucket each
    time, effectively bypassing the limit. Verify the key against your key
    store before using it as a rate-limit identity in production.
    """
    # Check for an API key first; an empty header value falls through to IP.
    api_key = request.headers.get("X-API-Key")
    if api_key:
        return f"api_key:{api_key}"
    # Fall back to the client IP address.
    if request.client:
        return f"ip:{request.client.host}"
    return "unknown"


# Option 1: Minimal middleware configuration (only limit and window_size).
# Uncomment to use:
# app.add_middleware(
#     RateLimitMiddleware,
#     limit=100,
#     window_size=60,
# )

# Option 2: Middleware with custom configuration
app.add_middleware(
    RateLimitMiddleware,
    limit=100,
    window_size=60,
    backend=MemoryBackend(),
    key_prefix="global",
    include_headers=True,
    error_message="You have exceeded the rate limit. Please slow down.",
    status_code=429,
    skip_on_error=True,  # Don't block requests if backend fails
    exempt_paths={"/health", "/docs", "/openapi.json", "/redoc"},
    # Exempt localhost.
    # NOTE(review): the server binds to DEFAULT_HOST (127.0.0.1) by default.
    # If this set is matched against the client IP, as it presumably is, every
    # request made from the local machine is exempt and no 429 will ever be
    # seen. Remove this entry to observe rate limiting while trying the
    # example locally.
    exempt_ips={"127.0.0.1"},
    key_extractor=get_client_identifier,
)

# Option 3: Convenience middleware for specific algorithms
# SlidingWindowMiddleware - precise rate limiting
# app.add_middleware(
#     SlidingWindowMiddleware,
#     limit=100,
#     window_size=60,
# )

# TokenBucketMiddleware - allows bursts
# app.add_middleware(
#     TokenBucketMiddleware,
#     limit=100,
#     window_size=60,
# )


@app.get("/")
async def root() -> dict[str, str]:
    """Root endpoint - rate limited by middleware."""
    return {"message": "Hello, World!"}


@app.get("/api/data")
async def get_data() -> dict[str, str]:
    """API endpoint - rate limited by middleware."""
    return {"data": "Some important data"}


@app.get("/api/users")
async def get_users() -> dict[str, list[str]]:
    """Users endpoint - rate limited by middleware."""
    return {"users": ["alice", "bob", "charlie"]}


@app.get("/health")
async def health() -> dict[str, str]:
    """Health check - exempt from rate limiting (listed in ``exempt_paths``)."""
    return {"status": "healthy"}


@app.get("/docs-info")
async def docs_info() -> dict[str, str]:
    """Info about documentation endpoints."""
    return {
        "message": "Visit /docs for Swagger UI or /redoc for ReDoc",
        "note": "These endpoints are exempt from rate limiting",
    }


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options and serve ``app`` with uvicorn.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior when run as a script.
    """
    # Imported here so that merely importing this module (e.g. by an ASGI
    # server pointed at ``app``) does not require uvicorn to be installed.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Middleware rate limiting example")
    parser.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help=f"Host to bind to (default: {DEFAULT_HOST})",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"Port to bind to (default: {DEFAULT_PORT})",
    )
    args = parser.parse_args(argv)
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()