"""Examples demonstrating rate limiting with FastAPI dependency injection.""" from __future__ import annotations from contextlib import asynccontextmanager from typing import Any from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse from fastapi_traffic import ( MemoryBackend, RateLimiter, RateLimitExceeded, ) from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.limiter import set_limiter backend = MemoryBackend() limiter = RateLimiter(backend) @asynccontextmanager async def lifespan(_: FastAPI): """Lifespan context manager for startup/shutdown.""" await limiter.initialize() set_limiter(limiter) yield await limiter.close() app = FastAPI(title="Dependency Injection Example", lifespan=lifespan) @app.exception_handler(RateLimitExceeded) async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse: return JSONResponse( status_code=429, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, ) # 1. Basic dependency - rate limit info available in endpoint basic_rate_limit = RateLimitDependency(limit=10, window_size=60) @app.get("/basic") async def basic_endpoint( _: Request, rate_info: Any = Depends(basic_rate_limit), ) -> dict[str, Any]: """Access rate limit info in your endpoint logic.""" return { "message": "Success", "rate_limit": { "limit": rate_info.limit, "remaining": rate_info.remaining, "reset_at": rate_info.reset_at, }, } # 2. Different limits for different user tiers def get_user_tier(request: Request) -> str: """Get user tier from header (in real app, from JWT/database).""" return request.headers.get("X-User-Tier", "free") free_tier_limit = RateLimitDependency( limit=10, window_size=60, key_prefix="free", ) pro_tier_limit = RateLimitDependency( limit=100, window_size=60, key_prefix="pro", ) enterprise_tier_limit = RateLimitDependency( limit=1000, window_size=60, key_prefix="enterprise", ) async def tiered_rate_limit( request: Request, tier: str = Depends(get_user_tier), ) -> Any: """Apply different rate limits based on user tier.""" if tier == "enterprise": return await enterprise_tier_limit(request) elif tier == "pro": return await pro_tier_limit(request) else: return await free_tier_limit(request) @app.get("/tiered") async def tiered_endpoint( request: Request, rate_info: Any = Depends(tiered_rate_limit), ) -> dict[str, Any]: """Endpoint with tier-based rate limiting.""" tier = get_user_tier(request) return { "message": "Success", "tier": tier, "rate_limit": { "limit": rate_info.limit, "remaining": rate_info.remaining, }, } # 3. Conditional rate limiting based on request properties def api_key_extractor(request: Request) -> str: """Extract API key for rate limiting.""" api_key = request.headers.get("X-API-Key", "anonymous") return f"api:{api_key}" api_rate_limit = RateLimitDependency( limit=100, window_size=3600, key_extractor=api_key_extractor, ) @app.get("/api/resource") async def api_resource( _: Request, rate_info: Any = Depends(api_rate_limit), ) -> dict[str, Any]: """API endpoint with per-API-key rate limiting.""" return { "data": "Resource data", "requests_remaining": rate_info.remaining, } # 4. Combine multiple rate limits (e.g., per-minute AND per-hour) per_minute_limit = RateLimitDependency( limit=10, window_size=60, key_prefix="minute", ) per_hour_limit = RateLimitDependency( limit=100, window_size=3600, key_prefix="hour", ) async def combined_rate_limit( _: Request, minute_info: Any = Depends(per_minute_limit), hour_info: Any = Depends(per_hour_limit), ) -> dict[str, Any]: """Apply both per-minute and per-hour limits.""" return { "minute": { "limit": minute_info.limit, "remaining": minute_info.remaining, }, "hour": { "limit": hour_info.limit, "remaining": hour_info.remaining, }, } @app.get("/combined") async def combined_endpoint( _: Request, rate_info: dict[str, Any] = Depends(combined_rate_limit), ) -> dict[str, Any]: """Endpoint with multiple rate limit tiers.""" return { "message": "Success", "rate_limits": rate_info, } # 5. Rate limit with custom exemption logic def is_internal_request(request: Request) -> bool: """Check if request is from internal service.""" internal_token = request.headers.get("X-Internal-Token") return internal_token == "internal-secret-token" internal_exempt_limit = RateLimitDependency( limit=10, window_size=60, exempt_when=is_internal_request, ) @app.get("/internal-exempt") async def internal_exempt_endpoint( request: Request, rate_info: Any = Depends(internal_exempt_limit), ) -> dict[str, Any]: """Internal requests are exempt from rate limiting.""" is_internal = is_internal_request(request) return { "message": "Success", "is_internal": is_internal, "rate_limit": ( None if is_internal else { "remaining": rate_info.remaining, } ), } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)