"""Examples demonstrating custom rate limit responses and callbacks.""" from __future__ import annotations import logging from contextlib import asynccontextmanager from datetime import datetime, timezone from typing import Any from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse from fastapi_traffic import ( MemoryBackend, RateLimiter, RateLimitExceeded, rate_limit, ) from fastapi_traffic.core.limiter import set_limiter logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) backend = MemoryBackend() limiter = RateLimiter(backend) @asynccontextmanager async def lifespan(_: FastAPI): await limiter.initialize() set_limiter(limiter) yield await limiter.close() DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 8000 app = FastAPI(title="Custom Responses Example", lifespan=lifespan) # 1. Standard JSON error response @app.exception_handler(RateLimitExceeded) async def json_rate_limit_handler( request: Request, exc: RateLimitExceeded ) -> JSONResponse: """Standard JSON response for API clients.""" headers = exc.limit_info.to_headers() if exc.limit_info else {} return JSONResponse( status_code=429, content={ "error": { "code": "RATE_LIMIT_EXCEEDED", "message": exc.message, "retry_after_seconds": exc.retry_after, "documentation_url": "https://docs.example.com/rate-limits", }, "request_id": request.headers.get("X-Request-ID", "unknown"), "timestamp": datetime.now(timezone.utc).isoformat(), }, headers=headers, ) # 2. 
# 2. Callback for logging/monitoring when requests are blocked
async def log_blocked_request(request: Request, info: Any) -> None:
    """Log a warning for each request rejected by the rate limiter.

    Passed as ``on_blocked`` to ``@rate_limit`` on ``/api/monitored`` below.

    Args:
        request: The request that was blocked.
        info: Limit state supplied by the limiter. Only ``info.remaining``
            is read; when ``info`` is falsy, ``"unknown"`` is logged instead.
    """
    client_ip = request.client.host if request.client else "unknown"
    path = request.url.path
    user_agent = request.headers.get("User-Agent", "unknown")
    # The path and User-Agent are client-controlled. %r renders them as
    # quoted reprs so embedded CR/LF or control characters cannot forge
    # extra log lines (log injection).
    logger.warning(
        "Rate limit exceeded: ip=%s path=%r user_agent=%r remaining=%s",
        client_ip,
        path,
        user_agent,
        info.remaining if info else "unknown",
    )
    # In production, you might:
    # - Send to metrics system (Prometheus, DataDog, etc.)
    # - Trigger alerts for suspicious patterns
    # - Update a blocklist for repeat offenders


@app.get("/api/monitored")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=log_blocked_request,
)
async def monitored_endpoint(_: Request) -> dict[str, str]:
    """Endpoint with blocked request logging (5 requests per 60 s)."""
    return {"message": "Success"}


# 3. Custom error messages per endpoint
@app.get("/api/search")
@rate_limit(
    limit=10,
    window_size=60,
    error_message="Search rate limit exceeded. Please wait before searching again.",
)
async def search_endpoint(_: Request, q: str = "") -> dict[str, Any]:
    """Search with custom error message; echoes ``q`` with no results."""
    return {"query": q, "results": []}


@app.get("/api/upload")
@rate_limit(
    limit=5,
    window_size=300,  # 5 uploads per 5 minutes
    error_message="Upload limit reached. You can upload 5 files every 5 minutes.",
)
async def upload_endpoint(_: Request) -> dict[str, str]:
    """Upload with custom error message."""
    return {"message": "Upload successful"}
# 4. Different response formats based on Accept header
@app.get("/api/flexible")
@rate_limit(limit=10, window_size=60)
async def flexible_endpoint(_: Request) -> dict[str, str]:
    """Endpoint that returns different formats."""
    return {"message": "Success", "data": "Some data"}


# Custom exception handler that respects the Accept header.
# FastAPI keeps one handler per exception class, so this registration serves
# every RateLimitExceeded in the app. JSON clients get the standard body
# built by ``json_rate_limit_handler`` (section 1), so that format is
# defined in one place.
@app.exception_handler(RateLimitExceeded)
async def flexible_rate_limit_handler(
    request: Request, exc: RateLimitExceeded
) -> HTMLResponse | PlainTextResponse | JSONResponse:
    """Return the 429 response in the format matching the Accept header.

    ``text/html`` gets an HTML page and ``text/plain`` a one-line message.
    Anything else, including a missing header, gets the standard JSON body.
    Media types are matched case-insensitively by substring; quality
    (``q=``) parameters are not considered.

    Args:
        request: The rate-limited request.
        exc: The rate-limit exception raised by the limiter.

    Returns:
        A 429 response carrying the headers from
        ``exc.limit_info.to_headers()`` when limit info exists.
    """
    from html import escape
    from math import ceil

    # Media types are case-insensitive (RFC 9110), so normalize before matching.
    accept = request.headers.get("Accept", "application/json").lower()
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    # Round up so clients are never told to retry before the window resets
    # (":.0f" showed "0 seconds" for 0.4). Guarded in case the limiter leaves
    # retry_after unset; otherwise formatting would turn the 429 into a 500.
    retry_text = (
        f"{ceil(exc.retry_after)} seconds"
        if exc.retry_after is not None
        else "a few moments"
    )

    if "text/html" in accept:
        # Escape the message so characters like "<" or "&" render literally
        # instead of being interpreted as markup.
        html_content = (
            "<!DOCTYPE html>"
            "<html><head><title>Rate Limit Exceeded</title></head><body>"
            "<h1>429 - Too Many Requests</h1>"
            f"<p>{escape(str(exc.message))}</p>"
            f"<p>Please try again in {retry_text}.</p>"
            "</body></html>"
        )
        return HTMLResponse(content=html_content, status_code=429, headers=headers)

    if "text/plain" in accept:
        return PlainTextResponse(
            content=f"Rate limit exceeded. Retry after {retry_text}.",
            status_code=429,
            headers=headers,
        )

    return await json_rate_limit_handler(request, exc)


# 5. Include helpful information in response headers
@app.get("/api/verbose-headers")
@rate_limit(
    limit=10,
    window_size=60,
    include_headers=True,  # Includes X-RateLimit-* headers
)
async def verbose_headers_endpoint(_: Request) -> dict[str, Any]:
    """Response includes detailed rate limit headers."""
    return {
        "message": "Check response headers for rate limit info",
        "headers_included": [
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
            "X-RateLimit-Reset",
        ],
    }


# 6. Graceful degradation - return cached/stale data instead of error
# NOTE: nothing in this example serves ``cached_data`` yet. As noted in
# ``return_cached_on_limit``, the on_blocked callback does not prevent the
# exception, so serving it needs custom middleware or a RateLimitExceeded
# handler that checks the request path.
cached_data = {
    "data": "Cached response",
    "cached_at": datetime.now(timezone.utc).isoformat(),
}


async def return_cached_on_limit(_: Request, __: Any) -> None:
    """Log when rate limited (callback doesn't prevent exception)."""
    logger.info("Returning cached data due to rate limit")
    # This callback is called when blocked, but doesn't prevent the exception
    # To actually return cached data, you'd need custom middleware


@app.get("/api/graceful")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=return_cached_on_limit,
)
async def graceful_endpoint(_: Request) -> dict[str, str]:
    """Return fresh data with a UTC timestamp (5 requests per 60 s).

    Once limited, clients receive the normal 429 response; the
    ``on_blocked`` callback only logs.
    """
    return {
        "message": "Fresh data",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }


if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Custom responses example")
    parser.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help=f"Host to bind to (default: {DEFAULT_HOST})",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"Port to bind to (default: {DEFAULT_PORT})",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)