Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
221
examples/06_dependency_injection.py
Normal file
221
examples/06_dependency_injection.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Examples demonstrating rate limiting with FastAPI dependency injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup/shutdown."""
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Dependency Injection Example", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
|
||||
)
|
||||
|
||||
|
||||
# 1. Basic dependency - rate limit info available in endpoint
|
||||
basic_rate_limit = RateLimitDependency(limit=10, window_size=60)
|
||||
|
||||
|
||||
@app.get("/basic")
|
||||
async def basic_endpoint(
|
||||
request: Request,
|
||||
rate_info: Any = Depends(basic_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Access rate limit info in your endpoint logic."""
|
||||
return {
|
||||
"message": "Success",
|
||||
"rate_limit": {
|
||||
"limit": rate_info.limit,
|
||||
"remaining": rate_info.remaining,
|
||||
"reset_at": rate_info.reset_at,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# 2. Different limits for different user tiers
|
||||
def get_user_tier(request: Request) -> str:
|
||||
"""Get user tier from header (in real app, from JWT/database)."""
|
||||
return request.headers.get("X-User-Tier", "free")
|
||||
|
||||
|
||||
free_tier_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
key_prefix="free",
|
||||
)
|
||||
|
||||
pro_tier_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=60,
|
||||
key_prefix="pro",
|
||||
)
|
||||
|
||||
enterprise_tier_limit = RateLimitDependency(
|
||||
limit=1000,
|
||||
window_size=60,
|
||||
key_prefix="enterprise",
|
||||
)
|
||||
|
||||
|
||||
async def tiered_rate_limit(
|
||||
request: Request,
|
||||
tier: str = Depends(get_user_tier),
|
||||
) -> Any:
|
||||
"""Apply different rate limits based on user tier."""
|
||||
if tier == "enterprise":
|
||||
return await enterprise_tier_limit(request)
|
||||
elif tier == "pro":
|
||||
return await pro_tier_limit(request)
|
||||
else:
|
||||
return await free_tier_limit(request)
|
||||
|
||||
|
||||
@app.get("/tiered")
|
||||
async def tiered_endpoint(
|
||||
request: Request,
|
||||
rate_info: Any = Depends(tiered_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Endpoint with tier-based rate limiting."""
|
||||
tier = get_user_tier(request)
|
||||
return {
|
||||
"message": "Success",
|
||||
"tier": tier,
|
||||
"rate_limit": {
|
||||
"limit": rate_info.limit,
|
||||
"remaining": rate_info.remaining,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# 3. Conditional rate limiting based on request properties
|
||||
def api_key_extractor(request: Request) -> str:
|
||||
"""Extract API key for rate limiting."""
|
||||
api_key = request.headers.get("X-API-Key", "anonymous")
|
||||
return f"api:{api_key}"
|
||||
|
||||
|
||||
api_rate_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=3600,
|
||||
key_extractor=api_key_extractor,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/resource")
|
||||
async def api_resource(
|
||||
request: Request,
|
||||
rate_info: Any = Depends(api_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""API endpoint with per-API-key rate limiting."""
|
||||
return {
|
||||
"data": "Resource data",
|
||||
"requests_remaining": rate_info.remaining,
|
||||
}
|
||||
|
||||
|
||||
# 4. Combine multiple rate limits (e.g., per-minute AND per-hour)
|
||||
per_minute_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
key_prefix="minute",
|
||||
)
|
||||
|
||||
per_hour_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=3600,
|
||||
key_prefix="hour",
|
||||
)
|
||||
|
||||
|
||||
async def combined_rate_limit(
|
||||
request: Request,
|
||||
minute_info: Any = Depends(per_minute_limit),
|
||||
hour_info: Any = Depends(per_hour_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Apply both per-minute and per-hour limits."""
|
||||
return {
|
||||
"minute": {
|
||||
"limit": minute_info.limit,
|
||||
"remaining": minute_info.remaining,
|
||||
},
|
||||
"hour": {
|
||||
"limit": hour_info.limit,
|
||||
"remaining": hour_info.remaining,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/combined")
|
||||
async def combined_endpoint(
|
||||
request: Request,
|
||||
rate_info: dict[str, Any] = Depends(combined_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Endpoint with multiple rate limit tiers."""
|
||||
return {
|
||||
"message": "Success",
|
||||
"rate_limits": rate_info,
|
||||
}
|
||||
|
||||
|
||||
# 5. Rate limit with custom exemption logic
|
||||
def is_internal_request(request: Request) -> bool:
|
||||
"""Check if request is from internal service."""
|
||||
internal_token = request.headers.get("X-Internal-Token")
|
||||
return internal_token == "internal-secret-token"
|
||||
|
||||
|
||||
internal_exempt_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
exempt_when=is_internal_request,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/internal-exempt")
|
||||
async def internal_exempt_endpoint(
|
||||
request: Request,
|
||||
rate_info: Any = Depends(internal_exempt_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Internal requests are exempt from rate limiting."""
|
||||
is_internal = is_internal_request(request)
|
||||
return {
|
||||
"message": "Success",
|
||||
"is_internal": is_internal,
|
||||
"rate_limit": None if is_internal else {
|
||||
"remaining": rate_info.remaining,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user