Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
153
examples/04_key_extractors.py
Normal file
153
examples/04_key_extractors.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Examples demonstrating custom key extractors for rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Custom Key Extractors", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
|
||||
)
|
||||
|
||||
|
||||
# 1. Default: Rate limit by IP address
|
||||
@app.get("/by-ip")
|
||||
@rate_limit(10, 60) # Uses default IP-based key extractor
|
||||
async def by_ip(request: Request) -> dict[str, str]:
|
||||
"""Rate limited by client IP address (default behavior)."""
|
||||
return {"limited_by": "ip", "client_ip": request.client.host if request.client else "unknown"}
|
||||
|
||||
|
||||
# 2. Rate limit by API key
|
||||
def api_key_extractor(request: Request) -> str:
|
||||
"""Extract API key from header."""
|
||||
api_key = request.headers.get("X-API-Key", "anonymous")
|
||||
return f"api_key:{api_key}"
|
||||
|
||||
|
||||
@app.get("/by-api-key")
|
||||
@rate_limit(
|
||||
limit=100,
|
||||
window_size=3600, # 100 requests per hour per API key
|
||||
key_extractor=api_key_extractor,
|
||||
)
|
||||
async def by_api_key(request: Request) -> dict[str, str]:
|
||||
"""Rate limited by API key from X-API-Key header."""
|
||||
api_key = request.headers.get("X-API-Key", "anonymous")
|
||||
return {"limited_by": "api_key", "api_key": api_key}
|
||||
|
||||
|
||||
# 3. Rate limit by user ID (from JWT or session)
|
||||
def user_id_extractor(request: Request) -> str:
|
||||
"""Extract user ID from request state or header."""
|
||||
# In real apps, this would come from decoded JWT or session
|
||||
user_id = request.headers.get("X-User-ID", "anonymous")
|
||||
return f"user:{user_id}"
|
||||
|
||||
|
||||
@app.get("/by-user")
|
||||
@rate_limit(
|
||||
limit=50,
|
||||
window_size=60,
|
||||
key_extractor=user_id_extractor,
|
||||
)
|
||||
async def by_user(request: Request) -> dict[str, str]:
|
||||
"""Rate limited by user ID."""
|
||||
user_id = request.headers.get("X-User-ID", "anonymous")
|
||||
return {"limited_by": "user_id", "user_id": user_id}
|
||||
|
||||
|
||||
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
|
||||
def endpoint_ip_extractor(request: Request) -> str:
|
||||
"""Combine endpoint path with IP for per-endpoint limits."""
|
||||
ip = request.client.host if request.client else "unknown"
|
||||
path = request.url.path
|
||||
return f"endpoint:{path}:ip:{ip}"
|
||||
|
||||
|
||||
@app.get("/endpoint-specific")
|
||||
@rate_limit(
|
||||
limit=5,
|
||||
window_size=60,
|
||||
key_extractor=endpoint_ip_extractor,
|
||||
)
|
||||
async def endpoint_specific(request: Request) -> dict[str, str]:
|
||||
"""Each endpoint has its own rate limit counter."""
|
||||
return {"limited_by": "endpoint+ip"}
|
||||
|
||||
|
||||
# 5. Rate limit by organization/tenant (multi-tenant apps)
|
||||
def tenant_extractor(request: Request) -> str:
|
||||
"""Extract tenant from subdomain or header."""
|
||||
# Could also parse from subdomain: tenant.example.com
|
||||
tenant = request.headers.get("X-Tenant-ID", "default")
|
||||
return f"tenant:{tenant}"
|
||||
|
||||
|
||||
@app.get("/by-tenant")
|
||||
@rate_limit(
|
||||
limit=1000,
|
||||
window_size=3600, # 1000 requests per hour per tenant
|
||||
key_extractor=tenant_extractor,
|
||||
)
|
||||
async def by_tenant(request: Request) -> dict[str, str]:
|
||||
"""Rate limited by tenant/organization."""
|
||||
tenant = request.headers.get("X-Tenant-ID", "default")
|
||||
return {"limited_by": "tenant", "tenant_id": tenant}
|
||||
|
||||
|
||||
# 6. Composite key: User + Action type
|
||||
def user_action_extractor(request: Request) -> str:
|
||||
"""Rate limit specific actions per user."""
|
||||
user_id = request.headers.get("X-User-ID", "anonymous")
|
||||
action = request.query_params.get("action", "default")
|
||||
return f"user:{user_id}:action:{action}"
|
||||
|
||||
|
||||
@app.get("/user-action")
|
||||
@rate_limit(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
key_extractor=user_action_extractor,
|
||||
)
|
||||
async def user_action(
|
||||
request: Request,
|
||||
action: str = "default",
|
||||
) -> dict[str, str]:
|
||||
"""Rate limited by user + action combination."""
|
||||
user_id = request.headers.get("X-User-ID", "anonymous")
|
||||
return {"limited_by": "user+action", "user_id": user_id, "action": action}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user