"""Advanced patterns and real-world use cases for rate limiting."""

from __future__ import annotations

import hashlib
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse

from fastapi_traffic import (
    Algorithm,
    MemoryBackend,
    RateLimiter,
    RateLimitExceeded,
    rate_limit,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter

# Module-level limiter; it is registered via set_limiter() during startup.
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Initialize the limiter on startup and close it on shutdown.

    The application instance is accepted to satisfy the lifespan signature
    but is not used.
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        # Close the limiter even if an exception is thrown into the
        # context manager while the application is running.
        await limiter.close()


app = FastAPI(title="Advanced Patterns", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Convert a RateLimitExceeded error into a 429 JSON response.

    The body carries a fixed error code and the exception's ``retry_after``
    value. Response headers come from ``exc.limit_info.to_headers()`` when
    ``limit_info`` is set; otherwise no extra headers are added.
    """
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
        headers=headers,
    )


# =============================================================================
# Pattern 1: Cost-based rate limiting
# Different operations consume different amounts of quota
# =============================================================================


@app.get("/api/list")
@rate_limit(limit=100, window_size=60, cost=1)
async def list_items(_: Request) -> dict[str, Any]:
    """Cheap operation - costs 1 token."""
    return {"items": ["a", "b", "c"], "cost": 1}


@app.get("/api/details/{item_id}")
@rate_limit(limit=100, window_size=60, cost=5)
async def get_details(_: Request, item_id: str) -> dict[str, Any]:
    """Medium operation - costs 5 tokens."""
    return {"item_id": item_id, "details": "...", "cost": 5}


@app.post("/api/generate")
@rate_limit(limit=100, window_size=60, cost=20)
async def generate_content(_: Request) -> dict[str, Any]:
    """Expensive operation - costs 20 tokens."""
    return {"generated": "AI-generated content...", "cost": 20}
@app.post("/api/bulk-export")
@rate_limit(limit=100, window_size=60, cost=50)
async def bulk_export(_: Request) -> dict[str, Any]:
    """Very expensive operation - costs 50 tokens."""
    return {"export_url": "https://...", "cost": 50}


# =============================================================================
# Pattern 2: Sliding scale exemptions
# Gradually reduce limits instead of hard blocking
# =============================================================================

# Priority levels returned by get_request_priority (higher = more important).
PREMIUM_PRIORITY = 100
AUTHENTICATED_PRIORITY = 50
ANONYMOUS_PRIORITY = 10

# Requests at or above this priority are exempt from rate limiting.
EXEMPT_PRIORITY_THRESHOLD = PREMIUM_PRIORITY


def get_request_priority(request: Request) -> int:
    """Determine request priority (higher = more important).

    Returns PREMIUM_PRIORITY when the ``X-Premium-User`` header equals
    ``"true"``, AUTHENTICATED_PRIORITY when a non-empty ``Authorization``
    header is present, and ANONYMOUS_PRIORITY otherwise.
    """
    # Premium users get higher priority
    if request.headers.get("X-Premium-User") == "true":
        return PREMIUM_PRIORITY

    # Authenticated users get medium priority
    if request.headers.get("Authorization"):
        return AUTHENTICATED_PRIORITY

    # Anonymous users get lowest priority
    return ANONYMOUS_PRIORITY


def should_exempt_high_priority(request: Request) -> bool:
    """Exempt high-priority requests from rate limiting."""
    return get_request_priority(request) >= EXEMPT_PRIORITY_THRESHOLD


@app.get("/api/priority-based")
@rate_limit(
    limit=10,
    window_size=60,
    exempt_when=should_exempt_high_priority,
)
async def priority_endpoint(request: Request) -> dict[str, Any]:
    """Premium users are exempt from rate limits.

    The response echoes the computed priority and whether it meets the
    exemption threshold.
    """
    priority = get_request_priority(request)
    return {
        "message": "Success",
        "priority": priority,
        "exempt": priority >= EXEMPT_PRIORITY_THRESHOLD,
    }


# =============================================================================
# Pattern 3: Rate limit by resource, not just user
# Prevent abuse of specific resources
# =============================================================================


def resource_key_extractor(request: Request) -> str:
    """Rate limit by resource ID + user.

    The resource ID comes from the ``resource_id`` path parameter
    (``"unknown"`` if absent) and the user from the ``X-User-ID`` header
    (``"anonymous"`` if absent).
    """
    resource_id = request.path_params.get("resource_id", "unknown")
    user_id = request.headers.get("X-User-ID", "anonymous")
    return f"resource:{resource_id}:user:{user_id}"


@app.get("/api/resources/{resource_id}")
@rate_limit(
    limit=10,
    window_size=60,
    key_extractor=resource_key_extractor,
)
async def get_resource(_: Request, resource_id: str) -> dict[str, str]:
    """Each user can access each resource 10 times per minute."""
    return {"resource_id": resource_id, "data": "..."}


# =============================================================================
# Pattern 4: Login/authentication rate limiting
# Prevent brute force attacks
# =============================================================================


def login_key_extractor(request: Request) -> str:
    """Rate limit by IP + username to prevent brute force.

    The IP falls back to ``"unknown"`` when the request has no client
    information; the username is read from the ``X-Username`` header.
    """
    ip = request.client.host if request.client else "unknown"
    # In real app, parse username from request body
    username = request.headers.get("X-Username", "unknown")
    return f"login:{ip}:{username}"


@app.post("/auth/login")
@rate_limit(
    limit=5,
    window_size=300,  # 5 attempts per 5 minutes
    algorithm=Algorithm.SLIDING_WINDOW,  # Precise tracking for security
    key_extractor=login_key_extractor,
    error_message="Too many login attempts. Please try again in 5 minutes.",
)
async def login(_: Request) -> dict[str, str]:
    """Login endpoint with brute force protection."""
    return {"message": "Login successful", "token": "..."}


# Password reset - even stricter limits
def password_reset_key(request: Request) -> str:
    """Rate limit password resets by client IP.

    Falls back to ``"unknown"`` when the request has no client information.
    """
    ip = request.client.host if request.client else "unknown"
    return f"password_reset:{ip}"


@app.post("/auth/password-reset")
@rate_limit(
    limit=3,
    window_size=3600,  # 3 attempts per hour
    key_extractor=password_reset_key,
    error_message="Too many password reset requests. Please try again later.",
)
async def password_reset(_: Request) -> dict[str, str]:
    """Password reset with strict rate limiting."""
    return {"message": "Password reset email sent"}


# =============================================================================
# Pattern 5: Webhook/callback rate limiting
# Limit outgoing requests to prevent overwhelming external services
# =============================================================================

webhook_rate_limit = RateLimitDependency(
    limit=100,
    window_size=60,
    key_prefix="webhook",
)


async def check_webhook_limit(
    _: Request,
    webhook_url: str,
) -> None:
    """Check rate limit before sending webhook.

    Simplified placeholder: it only derives the per-destination key and does
    not enforce any limit. In production, you'd use the limiter directly
    with this key.
    """
    from urllib.parse import urlparse

    # Create key based on destination domain
    domain = urlparse(webhook_url).netloc
    _key = f"webhook:{domain}"  # Would be used with limiter in production


@app.post("/api/send-webhook")
async def send_webhook(
    _: Request,
    webhook_url: str = "https://example.com/webhook",
    rate_info: Any = Depends(webhook_rate_limit),
) -> dict[str, Any]:
    """Send webhook with rate limiting to protect external services."""
    # NOTE: limiting here is done solely by the webhook_rate_limit dependency,
    # which is configured without a key extractor; check_webhook_limit (the
    # per-destination sketch) is not wired into this endpoint.
    return {
        "message": "Webhook sent",
        "destination": webhook_url,
        "remaining_quota": rate_info.remaining,
    }


# =============================================================================
# Pattern 6: Request fingerprinting
# Detect and limit similar requests (e.g., spam prevention)
# =============================================================================


def request_fingerprint(request: Request) -> str:
    """Create fingerprint based on request characteristics.

    Combines the client IP (``"unknown"`` if unavailable) with the
    ``User-Agent`` and ``Accept-Language`` headers, hashes the result, and
    keeps the first 16 hex characters of the digest.
    """
    ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "")
    accept_language = request.headers.get("Accept-Language", "")

    # Create hash of request characteristics
    fingerprint_data = f"{ip}:{user_agent}:{accept_language}"
    # MD5 is used only to bucket requests, not for security; saying so keeps
    # the call usable on FIPS-restricted builds and leaves the digest unchanged.
    digest = hashlib.md5(
        fingerprint_data.encode(), usedforsecurity=False
    ).hexdigest()
    fingerprint = digest[:16]
    return f"fingerprint:{fingerprint}"


@app.post("/api/submit-form")
@rate_limit(
    limit=5,
    window_size=60,
    key_extractor=request_fingerprint,
    error_message="Too many submissions from this device.",
)
async def submit_form(_: Request) -> dict[str, str]:
    """Form submission with fingerprint-based rate limiting."""
    return {"message": "Form submitted successfully"}


# =============================================================================
# Pattern 7: Time-of-day based limits
# Different limits during peak vs off-peak hours
# =============================================================================

# Peak window in UTC hours: start is inclusive, end is exclusive.
PEAK_START_HOUR_UTC = 9
PEAK_END_HOUR_UTC = 18


def is_peak_hours() -> bool:
    """Check if current time is during peak hours (9 AM - 6 PM UTC)."""
    current_hour = time.gmtime().tm_hour
    return PEAK_START_HOUR_UTC <= current_hour < PEAK_END_HOUR_UTC


def peak_aware_exempt(_: Request) -> bool:
    """Exempt requests during off-peak hours."""
    return not is_peak_hours()


@app.get("/api/peak-aware")
@rate_limit(
    limit=10,  # Strict limit during peak hours
    window_size=60,
    exempt_when=peak_aware_exempt,  # No limit during off-peak
)
async def peak_aware_endpoint(_: Request) -> dict[str, Any]:
    """Stricter limits during peak hours."""
    # Evaluate once so both fields agree even if the hour rolls over
    # while the response is being built.
    peak = is_peak_hours()
    return {
        "message": "Success",
        "is_peak_hours": peak,
        "rate_limited": peak,
    }


# =============================================================================
# Pattern 8: Cascading limits (multiple tiers)
# =============================================================================

per_second = RateLimitDependency(limit=5, window_size=1, key_prefix="sec")
per_minute = RateLimitDependency(limit=100, window_size=60, key_prefix="min")
per_hour = RateLimitDependency(limit=1000, window_size=3600, key_prefix="hour")


async def cascading_limits(
    _: Request,
    sec_info: Any = Depends(per_second),
    min_info: Any = Depends(per_minute),
    hour_info: Any = Depends(per_hour),
) -> dict[str, Any]:
    """Apply multiple rate limit tiers.

    Depends on the per-second, per-minute, and per-hour dependencies and
    returns the ``remaining`` value reported by each.
    """
    return {
        "per_second": {"remaining": sec_info.remaining},
        "per_minute": {"remaining": min_info.remaining},
        "per_hour": {"remaining": hour_info.remaining},
    }


@app.get("/api/cascading")
async def cascading_endpoint(
    _: Request,
    limits: dict[str, Any] = Depends(cascading_limits),
) -> dict[str, Any]:
    """Endpoint with per-second, per-minute, and per-hour limits."""
    return {"message": "Success", "limits": limits}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)