Files
fastapi-traffic/examples/04_key_extractors.py
zanewalker f3453cb0fc release: bump version to 0.3.0
- Refactor Redis backend connection handling and pool management
- Update algorithm implementations with improved type annotations
- Enhance config loader validation with stricter Pydantic schemas
- Improve decorator and middleware error handling
- Expand example scripts with better docstrings and usage patterns
- Add new 00_basic_usage.py example for quick start
- Reorganize examples directory structure
- Fix type annotation inconsistencies across core modules
- Update dependencies in pyproject.toml
2026-03-17 21:04:34 +00:00

176 lines
4.9 KiB
Python

"""Examples demonstrating custom key extractors for rate limiting."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Default bind address and port for the uvicorn server started in the
# ``__main__`` section of this script; both can be overridden there with
# the --host / --port command-line options.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000
# Module-level rate limiter built on a MemoryBackend. ``lifespan``
# initializes it, registers it through set_limiter() and closes it again.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Initialize the module-level limiter for the lifetime of the app.

    Before the ``yield`` the limiter is initialized and registered with
    ``set_limiter``; after it the limiter is closed.

    The ``yield`` is wrapped in ``try/finally`` so that ``limiter.close()``
    still runs when an exception is thrown into the context manager while
    the application is running, instead of leaking the initialized limiter.
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        # Always release the limiter, even on an abnormal shutdown.
        await limiter.close()
# Application instance; ``lifespan`` is passed in so the limiter
# setup/teardown defined there is tied to the app.
app = FastAPI(title="Custom Key Extractors", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Convert a ``RateLimitExceeded`` error into an HTTP 429 JSON response.

    The JSON body contains ``error`` and ``retry_after`` (taken from
    ``exc.retry_after``). When that value is a finite number, the same delay
    is also sent as a ``Retry-After`` header so that HTTP clients that only
    inspect headers can back off correctly.

    Args:
        _: The request that triggered the error (unused).
        exc: The raised rate-limit exception.

    Returns:
        A ``JSONResponse`` with status code 429.
    """
    import math

    retry_after = exc.retry_after
    headers: dict[str, str] = {}
    # NOTE(review): assumes retry_after is expressed in seconds — confirm
    # against the RateLimitExceeded definition.
    if isinstance(retry_after, (int, float)) and math.isfinite(retry_after):
        # Retry-After carries whole seconds; round up so a client never
        # retries before the limit has actually reset.
        headers["Retry-After"] = str(max(0, math.ceil(retry_after)))
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": retry_after},
        headers=headers,
    )
# 1. Default behaviour: requests are bucketed by client IP address
@app.get("/by-ip")
@rate_limit(10, 60)  # no key_extractor passed: default IP-based extractor
async def by_ip(request: Request) -> dict[str, str]:
    """Report the caller's IP; limited per client IP (the default)."""
    client = request.client
    # ``request.client`` may be missing, in which case a placeholder is used.
    client_ip = client.host if client else "unknown"
    return {"limited_by": "ip", "client_ip": client_ip}
# 2. Rate limit by API key
def api_key_extractor(request: Request) -> str:
    """Build the rate-limit key from the ``X-API-Key`` request header.

    Requests without the header all map to the ``anonymous`` key.
    """
    headers = request.headers
    key_value = headers.get("X-API-Key", "anonymous")
    return f"api_key:{key_value}"
@app.get("/by-api-key")
@rate_limit(
    key_extractor=api_key_extractor,
    limit=100,
    window_size=3600,  # 100 requests per hour per API key
)
async def by_api_key(request: Request) -> dict[str, str]:
    """Echo the ``X-API-Key`` header; limited per API key."""
    payload = {"limited_by": "api_key"}
    # Same header and fallback as ``api_key_extractor`` uses for the key.
    payload["api_key"] = request.headers.get("X-API-Key", "anonymous")
    return payload
# 3. Rate limit by user ID (from JWT or session)
def user_id_extractor(request: Request) -> str:
    """Build the rate-limit key from the ``X-User-ID`` request header.

    Falls back to ``anonymous`` when the header is absent.
    """
    # A real application would take the ID from a decoded JWT or a session
    # rather than from a plain header.
    return f"user:{request.headers.get('X-User-ID', 'anonymous')}"
@app.get("/by-user")
@rate_limit(limit=50, window_size=60, key_extractor=user_id_extractor)
async def by_user(request: Request) -> dict[str, str]:
    """Echo the ``X-User-ID`` header; limited per user ID."""
    headers = request.headers
    return {
        "limited_by": "user_id",
        "user_id": headers.get("X-User-ID", "anonymous"),
    }
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
def endpoint_ip_extractor(request: Request) -> str:
    """Build a key from the request path and the client IP.

    Including the path gives every endpoint its own counter per client.
    """
    client = request.client
    if client:
        client_ip = client.host
    else:
        # No client information available on the request.
        client_ip = "unknown"
    return f"endpoint:{request.url.path}:ip:{client_ip}"
@app.get("/endpoint-specific")
@rate_limit(limit=5, window_size=60, key_extractor=endpoint_ip_extractor)
async def endpoint_specific(_: Request) -> dict[str, str]:
    """Limited by path + client IP, so this endpoint has its own counter."""
    return dict(limited_by="endpoint+ip")
# 5. Rate limit by organization/tenant (multi-tenant apps)
def tenant_extractor(request: Request) -> str:
    """Build the rate-limit key from the ``X-Tenant-ID`` request header.

    Requests without the header fall into the ``default`` tenant.
    """
    # Alternative not implemented here: derive the tenant from the
    # subdomain, e.g. tenant.example.com.
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return f"tenant:{tenant_id}"
@app.get("/by-tenant")
@rate_limit(
    key_extractor=tenant_extractor,
    limit=1000,
    window_size=3600,  # 1000 requests per hour per tenant
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Echo the ``X-Tenant-ID`` header; limited per tenant/organization."""
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    result = {"limited_by": "tenant"}
    result["tenant_id"] = tenant_id
    return result
# 6. Composite key: User + Action type
def user_action_extractor(request: Request) -> str:
    """Build a key combining the user ID header and the ``action`` query param.

    The user comes from ``X-User-ID`` (fallback ``anonymous``) and the action
    from the ``action`` query parameter (fallback ``default``), so each
    user/action pair is counted separately.
    """
    who = request.headers.get("X-User-ID", "anonymous")
    what = request.query_params.get("action", "default")
    return f"user:{who}:action:{what}"
@app.get("/user-action")
@rate_limit(limit=10, window_size=60, key_extractor=user_action_extractor)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Echo the user ID header and ``action``; limited per user + action."""
    body = {"limited_by": "user+action"}
    body["user_id"] = request.headers.get("X-User-ID", "anonymous")
    body["action"] = action
    return body
if __name__ == "__main__":
    import argparse

    import uvicorn

    # Command-line options controlling where the example server listens;
    # both fall back to the module-level defaults.
    cli = argparse.ArgumentParser(description="Custom key extractors example")
    cli.add_argument(
        "--host",
        help=f"Host to bind to (default: {DEFAULT_HOST})",
        default=DEFAULT_HOST,
    )
    cli.add_argument(
        "--port",
        help=f"Port to bind to (default: {DEFAULT_PORT})",
        default=DEFAULT_PORT,
        type=int,
    )
    options = cli.parse_args()

    # Serve the example app on the requested address.
    uvicorn.run(app, host=options.host, port=options.port)