"""Examples demonstrating custom key extractors for rate limiting.""" from __future__ import annotations from contextlib import asynccontextmanager from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from fastapi_traffic import ( MemoryBackend, RateLimitExceeded, RateLimiter, rate_limit, ) from fastapi_traffic.core.limiter import set_limiter backend = MemoryBackend() limiter = RateLimiter(backend) @asynccontextmanager async def lifespan(app: FastAPI): await limiter.initialize() set_limiter(limiter) yield await limiter.close() app = FastAPI(title="Custom Key Extractors", lifespan=lifespan) @app.exception_handler(RateLimitExceeded) async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: return JSONResponse( status_code=429, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, ) # 1. Default: Rate limit by IP address @app.get("/by-ip") @rate_limit(10, 60) # Uses default IP-based key extractor async def by_ip(request: Request) -> dict[str, str]: """Rate limited by client IP address (default behavior).""" return {"limited_by": "ip", "client_ip": request.client.host if request.client else "unknown"} # 2. Rate limit by API key def api_key_extractor(request: Request) -> str: """Extract API key from header.""" api_key = request.headers.get("X-API-Key", "anonymous") return f"api_key:{api_key}" @app.get("/by-api-key") @rate_limit( limit=100, window_size=3600, # 100 requests per hour per API key key_extractor=api_key_extractor, ) async def by_api_key(request: Request) -> dict[str, str]: """Rate limited by API key from X-API-Key header.""" api_key = request.headers.get("X-API-Key", "anonymous") return {"limited_by": "api_key", "api_key": api_key} # 3. 
# 3. Rate limit by user ID (from JWT or session)
def user_id_extractor(request: Request) -> str:
    """Extract the user ID from the ``X-User-ID`` header.

    A missing *or empty* header falls back to the shared ``anonymous`` bucket.
    """
    # In real apps, this would come from decoded JWT or session: a plain
    # client-supplied header can be rotated to dodge the limit, or set to
    # another user's ID to exhaust that user's quota.
    user_id = request.headers.get("X-User-ID") or "anonymous"
    return f"user:{user_id}"


@app.get("/by-user")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_id_extractor,
)
async def by_user(request: Request) -> dict[str, str]:
    """Rate limited by user ID."""
    user_id = request.headers.get("X-User-ID") or "anonymous"
    return {"limited_by": "user_id", "user_id": user_id}


# 4. Rate limit by endpoint + IP (separate limits per endpoint)
def endpoint_ip_extractor(request: Request) -> str:
    """Combine the endpoint's route path with the client IP for per-endpoint limits."""
    ip = request.client.host if request.client else "unknown"
    # Prefer the matched route's path template (e.g. "/items/{item_id}") so all
    # URLs served by one endpoint share a counter instead of one counter per
    # concrete URL. FastAPI's APIRoute records itself under scope["route"];
    # fall back to the literal URL path when no route is recorded there.
    route = request.scope.get("route")
    path = getattr(route, "path", None) or request.url.path
    return f"endpoint:{path}:ip:{ip}"


@app.get("/endpoint-specific")
@rate_limit(
    limit=5,
    window_size=60,
    key_extractor=endpoint_ip_extractor,
)
async def endpoint_specific(request: Request) -> dict[str, str]:
    """Each endpoint has its own rate limit counter."""
    return {"limited_by": "endpoint+ip"}


# 5. Rate limit by organization/tenant (multi-tenant apps)
def tenant_extractor(request: Request) -> str:
    """Extract the tenant from the ``X-Tenant-ID`` header.

    A missing *or empty* header falls back to the shared ``default`` bucket.
    """
    # Could also parse from subdomain: tenant.example.com
    tenant = request.headers.get("X-Tenant-ID") or "default"
    return f"tenant:{tenant}"


@app.get("/by-tenant")
@rate_limit(
    limit=1000,
    window_size=3600,  # 1000 requests per hour per tenant
    key_extractor=tenant_extractor,
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Rate limited by tenant/organization."""
    tenant = request.headers.get("X-Tenant-ID") or "default"
    return {"limited_by": "tenant", "tenant_id": tenant}
# 6. Composite key: User + Action type
def user_action_extractor(request: Request) -> str:
    """Rate limit specific actions per user.

    The user ID comes from ``X-User-ID`` (missing or empty -> ``anonymous``)
    and the action from the ``action`` query parameter (missing -> ``default``).

    NOTE(review): ``action`` is free-form client input, so every distinct value
    gets its own bucket; varying ``?action=`` sidesteps the limit and grows
    backend state. Real apps should map it onto a fixed set of known actions.
    """
    user_id = request.headers.get("X-User-ID") or "anonymous"
    action = request.query_params.get("action", "default")
    return f"user:{user_id}:action:{action}"


@app.get("/user-action")
@rate_limit(
    limit=10,
    window_size=60,
    key_extractor=user_action_extractor,
)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Rate limited by user + action combination."""
    # Same fallback as user_action_extractor so the response names the bucket.
    user_id = request.headers.get("X-User-ID") or "anonymous"
    return {"limited_by": "user+action", "user_id": user_id, "action": action}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)