Files
fastapi-traffic/examples/09_custom_responses.py
zanewalker f3453cb0fc release: bump version to 0.3.0
- Refactor Redis backend connection handling and pool management
- Update algorithm implementations with improved type annotations
- Enhance config loader validation with stricter Pydantic schemas
- Improve decorator and middleware error handling
- Expand example scripts with better docstrings and usage patterns
- Add new 00_basic_usage.py example for quick start
- Reorganize examples directory structure
- Fix type annotation inconsistencies across core modules
- Update dependencies in pyproject.toml
2026-03-17 21:04:34 +00:00

236 lines
6.7 KiB
Python

"""Examples demonstrating custom rate limit responses and callbacks."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi_traffic import (
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Emit INFO and above so the logger.info / logger.warning calls in the
# rate-limit callbacks of this example actually show up on the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Single limiter for the whole example. ``lifespan`` initializes it, installs
# it with set_limiter(), and closes it on shutdown.
# NOTE(review): presumably @rate_limit looks up the limiter installed by
# set_limiter() at request time, and MemoryBackend keeps counters in process
# memory -- confirm in fastapi_traffic before relying on either.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the shared rate limiter for the lifetime of the application.

    On startup the limiter is initialized and installed globally with
    ``set_limiter``. On shutdown it is closed. The close happens in a
    ``finally`` so it also runs when the lifespan exits with an error.
    Otherwise ``asynccontextmanager`` would re-raise at the ``yield`` and
    skip the cleanup.
    """
    await limiter.initialize()
    try:
        set_limiter(limiter)
        yield
    finally:
        # Always release backend resources, even on failure/cancellation.
        await limiter.close()
# Bind address/port used by the command-line entry point at the bottom of
# this file (overridable with --host / --port).
DEFAULT_HOST: str = "127.0.0.1"
DEFAULT_PORT: int = 8000

# The example application; ``lifespan`` owns the limiter's setup/teardown.
app = FastAPI(lifespan=lifespan, title="Custom Responses Example")
# 1. Standard JSON error response
#
# Starlette keeps exactly ONE handler per exception class. This file registers
# ``flexible_rate_limit_handler`` for RateLimitExceeded further down, so
# decorating this function with @app.exception_handler(RateLimitExceeded)
# would be silently overwritten. To use this handler instead, register it
# explicitly (and drop the other registration):
#
#     app.add_exception_handler(RateLimitExceeded, json_rate_limit_handler)
async def json_rate_limit_handler(
    request: Request, exc: RateLimitExceeded
) -> JSONResponse:
    """Build a structured JSON 429 response for API clients.

    The body carries an error object (code, message, retry-after seconds,
    documentation link), the caller's ``X-Request-ID`` header (or
    ``"unknown"``), and a UTC ISO-8601 timestamp. Rate-limit headers from
    ``exc.limit_info`` are attached when it is set.
    """
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(
        status_code=429,
        content={
            "error": {
                "code": "RATE_LIMIT_EXCEEDED",
                "message": exc.message,
                "retry_after_seconds": exc.retry_after,
                "documentation_url": "https://docs.example.com/rate-limits",
            },
            "request_id": request.headers.get("X-Request-ID", "unknown"),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
        headers=headers,
    )
# 2. Callback for logging/monitoring when requests are blocked
async def log_blocked_request(request: Request, info: Any) -> None:
    """Write a warning describing a request the rate limiter rejected.

    Logs the client IP, path, and User-Agent (each falling back to
    ``"unknown"``). It also logs ``info.remaining`` when ``info`` is truthy,
    and ``"unknown"`` otherwise.
    """
    peer = request.client
    ip = "unknown" if not peer else peer.host
    route = request.url.path
    agent = request.headers.get("User-Agent", "unknown")
    logger.warning(
        "Rate limit exceeded: ip=%s path=%s user_agent=%s remaining=%s",
        ip,
        route,
        agent,
        info.remaining if info else "unknown",
    )
    # Production ideas for this hook:
    # - Send to metrics system (Prometheus, DataDog, etc.)
    # - Trigger alerts for suspicious patterns
    # - Update a blocklist for repeat offenders
@app.get("/api/monitored")
@rate_limit(
    window_size=60,
    limit=5,
    on_blocked=log_blocked_request,
)
async def monitored_endpoint(_: Request) -> dict[str, str]:
    """Endpoint with blocked request logging."""
    payload = {"message": "Success"}
    return payload
# 3. Custom error messages per endpoint
@app.get("/api/search")
@rate_limit(
    window_size=60,
    limit=10,
    error_message=(
        "Search rate limit exceeded. "
        "Please wait before searching again."
    ),
)
async def search_endpoint(_: Request, q: str = "") -> dict[str, Any]:
    """Search with custom error message."""
    hits: list[Any] = []
    return {"query": q, "results": hits}
@app.get("/api/upload")
@rate_limit(
    window_size=5 * 60,  # 5 uploads per 5 minutes
    limit=5,
    error_message=(
        "Upload limit reached. "
        "You can upload 5 files every 5 minutes."
    ),
)
async def upload_endpoint(_: Request) -> dict[str, str]:
    """Upload with custom error message."""
    outcome = {"message": "Upload successful"}
    return outcome
# 4. Different response formats based on Accept header
@app.get("/api/flexible")
@rate_limit(window_size=60, limit=10)
async def flexible_endpoint(_: Request) -> dict[str, str]:
    """Endpoint that returns different formats."""
    return dict(message="Success", data="Some data")
# Custom exception handler that respects the Accept header.
# Starlette stores one handler per exception class, so this registration
# replaces any handler registered earlier for RateLimitExceeded.
@app.exception_handler(RateLimitExceeded)
async def flexible_rate_limit_handler(
    request: Request, exc: RateLimitExceeded
) -> HTMLResponse | PlainTextResponse | JSONResponse:
    """Return a 429 response in the format the client asks for.

    Matching is a simple substring test on the ``Accept`` header, with
    q-values ignored. ``text/html`` is checked first, then ``text/plain``.
    Anything else, including a missing header, gets JSON. Rate-limit headers
    from ``exc.limit_info`` are attached to every variant when it is set.
    """
    from html import escape

    # Media types are case-insensitive, so normalize before substring checks.
    accept = request.headers.get("Accept", "application/json").lower()
    headers = exc.limit_info.to_headers() if exc.limit_info else {}

    # NOTE(review): ":.0f" below assumes retry_after is numeric (it would
    # raise on None) -- confirm RateLimitExceeded always sets it.
    if "text/html" in accept:
        # The message is interpolated into markup, so HTML-escape it.
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head><title>Rate Limit Exceeded</title></head>
        <body>
            <h1>429 - Too Many Requests</h1>
            <p>{escape(str(exc.message))}</p>
            <p>Please try again in {exc.retry_after:.0f} seconds.</p>
        </body>
        </html>
        """
        return HTMLResponse(content=html_content, status_code=429, headers=headers)

    if "text/plain" in accept:
        return PlainTextResponse(
            content=f"Rate limit exceeded. Retry after {exc.retry_after:.0f} seconds.",
            status_code=429,
            headers=headers,
        )

    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=headers,
    )
# 5. Include helpful information in response headers
@app.get("/api/verbose-headers")
@rate_limit(
    window_size=60,
    limit=10,
    include_headers=True,  # Includes X-RateLimit-* headers
)
async def verbose_headers_endpoint(_: Request) -> dict[str, Any]:
    """Response includes detailed rate limit headers."""
    advertised = [
        "X-RateLimit-Limit",
        "X-RateLimit-Remaining",
        "X-RateLimit-Reset",
    ]
    return {
        "message": "Check response headers for rate limit info",
        "headers_included": advertised,
    }
# 6. Graceful degradation (illustrative sketch)
#
# ``cached_data`` is the kind of stale payload a graceful-degradation setup
# would serve instead of a 429. Nothing in this file actually serves it: an
# ``on_blocked`` callback does not prevent RateLimitExceeded, so returning
# this instead of an error would need custom middleware.
# The timestamp is captured once, at import time.
cached_data = dict(
    data="Cached response",
    cached_at=datetime.now(timezone.utc).isoformat(),
)
async def return_cached_on_limit(_: Request, __: Any) -> None:
    """``on_blocked`` callback for the graceful-degradation example.

    Despite its name, this only logs. The callback is invoked when a request
    is blocked, but it does not prevent RateLimitExceeded from being raised,
    so the client still receives an error rather than cached data. Actually
    serving ``cached_data`` would require custom middleware.
    """
    # The old message ("Returning cached data ...") described something that
    # never happens; log what really occurs instead.
    logger.info(
        "Rate limit hit on graceful endpoint; cached data is not served "
        "(on_blocked cannot suppress RateLimitExceeded)"
    )
@app.get("/api/graceful")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=return_cached_on_limit,
)
async def graceful_endpoint(_: Request) -> dict[str, str]:
    """Return fresh data stamped with the current UTC time.

    When over the limit, ``return_cached_on_limit`` is called but the request
    still fails with RateLimitExceeded. This endpoint does not fall back to
    cached data.
    """
    return {
        "message": "Fresh data",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
if __name__ == "__main__":
    # Script entry point: parse --host/--port and serve the app with uvicorn.
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser(description="Custom responses example")
    cli.add_argument(
        "--host",
        help=f"Host to bind to (default: {DEFAULT_HOST})",
        default=DEFAULT_HOST,
    )
    cli.add_argument(
        "--port",
        help=f"Port to bind to (default: {DEFAULT_PORT})",
        default=DEFAULT_PORT,
        type=int,
    )
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)