# fastapi-traffic/examples/09_custom_responses.py
"""Examples demonstrating custom rate limit responses and callbacks."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi_traffic import (
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Configure root logging so the blocked-request warnings emitted by the
# callbacks below are visible on the console when running this example.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# One in-memory backend/limiter pair shared by every endpoint in this file;
# counters live in-process only and reset on restart.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the shared rate limiter for the application's lifetime.

    On startup: initialize the backend and register the limiter globally so
    the ``rate_limit`` decorator can find it. On shutdown: close the limiter.

    The ``try/finally`` guarantees ``limiter.close()`` runs even when an
    exception is thrown into the generator during shutdown (the original
    code skipped cleanup in that case).
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        await limiter.close()
app = FastAPI(title="Custom Responses Example", lifespan=lifespan)
# 1. Standard JSON error response
# 1. Standard JSON error response
@app.exception_handler(RateLimitExceeded)
async def json_rate_limit_handler(
    request: Request, exc: RateLimitExceeded
) -> JSONResponse:
    """Build a structured JSON 429 response for API clients.

    NOTE: a second handler for RateLimitExceeded is registered later in this
    file; only one registration can be active per exception type.
    """
    rate_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    error_detail = {
        "code": "RATE_LIMIT_EXCEEDED",
        "message": exc.message,
        "retry_after_seconds": exc.retry_after,
        "documentation_url": "https://docs.example.com/rate-limits",
    }
    body = {
        "error": error_detail,
        "request_id": request.headers.get("X-Request-ID", "unknown"),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return JSONResponse(status_code=429, content=body, headers=rate_headers)
# 2. Callback for logging/monitoring when requests are blocked
# 2. Callback for logging/monitoring when requests are blocked
async def log_blocked_request(request: Request, info: Any) -> None:
    """Emit a warning log entry for each request blocked by the limiter."""
    client = request.client
    logger.warning(
        "Rate limit exceeded: ip=%s path=%s user_agent=%s remaining=%s",
        client.host if client else "unknown",
        request.url.path,
        request.headers.get("User-Agent", "unknown"),
        info.remaining if info else "unknown",
    )
    # In production, you might:
    # - Send to metrics system (Prometheus, DataDog, etc.)
    # - Trigger alerts for suspicious patterns
    # - Update a blocklist for repeat offenders
@app.get("/api/monitored")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=log_blocked_request,
)
async def monitored_endpoint(_: Request) -> dict[str, str]:
    """Endpoint whose blocked requests are reported via ``log_blocked_request``."""
    payload = {"message": "Success"}
    return payload
# 3. Custom error messages per endpoint
@app.get("/api/search")
@rate_limit(
    limit=10,
    window_size=60,
    error_message="Search rate limit exceeded. Please wait before searching again.",
)
async def search_endpoint(_: Request, q: str = "") -> dict[str, Any]:
    """Search endpoint with a search-specific message when rate limited."""
    response: dict[str, Any] = {"query": q, "results": []}
    return response
@app.get("/api/upload")
@rate_limit(
    limit=5,
    window_size=300,  # 5 uploads per 5 minutes
    error_message="Upload limit reached. You can upload 5 files every 5 minutes.",
)
async def upload_endpoint(_: Request) -> dict[str, str]:
    """Upload endpoint with an upload-specific message when rate limited."""
    result = {"message": "Upload successful"}
    return result
# 4. Different response formats based on Accept header
@app.get("/api/flexible")
@rate_limit(limit=10, window_size=60)
async def flexible_endpoint(_: Request) -> dict[str, str]:
    """Endpoint paired with the Accept-header-aware 429 handler below."""
    data = {"message": "Success", "data": "Some data"}
    return data
# Custom exception handler that respects Accept header
@app.exception_handler(RateLimitExceeded)
async def flexible_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Render the 429 response in the format the client's Accept header asks for.

    Falls back to JSON when the client requests neither HTML nor plain text.
    NOTE(review): the f-string formatting assumes ``exc.retry_after`` is a
    number, not None — confirm against RateLimitExceeded's contract.
    """
    accept = request.headers.get("Accept", "application/json")
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    if "text/html" in accept:
        # Browser clients get a small human-readable error page.
        page = f"""
<!DOCTYPE html>
<html>
<head><title>Rate Limit Exceeded</title></head>
<body>
<h1>429 - Too Many Requests</h1>
<p>{exc.message}</p>
<p>Please try again in {exc.retry_after:.0f} seconds.</p>
</body>
</html>
"""
        return HTMLResponse(content=page, status_code=429, headers=headers)
    if "text/plain" in accept:
        # CLI tools such as curl often prefer plain text.
        return PlainTextResponse(
            content=f"Rate limit exceeded. Retry after {exc.retry_after:.0f} seconds.",
            status_code=429,
            headers=headers,
        )
    # Default: machine-readable JSON.
    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=headers,
    )
# 5. Include helpful information in response headers
@app.get("/api/verbose-headers")
@rate_limit(
    limit=10,
    window_size=60,
    include_headers=True,  # Includes X-RateLimit-* headers
)
async def verbose_headers_endpoint(_: Request) -> dict[str, Any]:
    """Successful responses carry X-RateLimit-* headers describing the quota."""
    header_names = [
        "X-RateLimit-Limit",
        "X-RateLimit-Remaining",
        "X-RateLimit-Reset",
    ]
    return {
        "message": "Check response headers for rate limit info",
        "headers_included": header_names,
    }
# 6. Graceful degradation - return cached/stale data instead of error
# Module-level "cache" captured once at import time; used by the graceful
# degradation example below. In a real app this would be a proper cache layer.
cached_data = {
    "data": "Cached response",
    "cached_at": datetime.now(timezone.utc).isoformat(),
}
async def return_cached_on_limit(_: Request, __: Any) -> None:
    """Log that a rate limit was hit (informational only).

    An ``on_blocked`` callback cannot suppress the RateLimitExceeded
    exception; actually serving cached data in place of the error would
    require custom middleware.
    """
    logger.info("Returning cached data due to rate limit")
@app.get("/api/graceful")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=return_cached_on_limit,
)
async def graceful_endpoint(_: Request) -> dict[str, str]:
    """Return fresh data; the ``on_blocked`` callback only logs when limited."""
    now = datetime.now(timezone.utc).isoformat()
    return {"message": "Fresh data", "timestamp": now}
# Run a local development server when this module is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)