Files
fastapi-traffic/examples/06_dependency_injection.py
zanewalker f3453cb0fc release: bump version to 0.3.0
- Refactor Redis backend connection handling and pool management
- Update algorithm implementations with improved type annotations
- Enhance config loader validation with stricter Pydantic schemas
- Improve decorator and middleware error handling
- Expand example scripts with better docstrings and usage patterns
- Add new 00_basic_usage.py example for quick start
- Reorganize examples directory structure
- Fix type annotation inconsistencies across core modules
- Update dependencies in pyproject.toml
2026-03-17 21:04:34 +00:00

268 lines
6.5 KiB
Python

"""Examples demonstrating rate limiting with FastAPI dependency injection."""
from __future__ import annotations

import secrets
from contextlib import asynccontextmanager
from typing import Annotated, Any, TypeAlias

from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse

from fastapi_traffic import (
    MemoryBackend,
    RateLimiter,
    RateLimitExceeded,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.core.models import RateLimitInfo
# Bind address used by the ``__main__`` entry point; both values can be
# overridden with --host / --port on the command line.
DEFAULT_HOST: str = "127.0.0.1"
DEFAULT_PORT: int = 8000

# Storage backend and the limiter built on top of it. ``lifespan`` below
# initialises the limiter, registers it via ``set_limiter`` and closes it.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the rate limiter across the application's lifetime.

    On startup the module-level ``limiter`` is initialised and registered
    with ``set_limiter``; on shutdown it is closed.
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        # If an exception is thrown into the generator at ``yield`` (e.g. the
        # server aborts), code after ``yield`` would otherwise be skipped and
        # the limiter's resources would never be released.
        await limiter.close()
app = FastAPI(title="Dependency Injection Example", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn a ``RateLimitExceeded`` error into a JSON HTTP 429 response.

    The body holds a fixed error code and the exception's ``retry_after``
    value, passed through unchanged.
    """
    payload = {
        "error": "rate_limit_exceeded",
        "retry_after": exc.retry_after,
    }
    return JSONResponse(content=payload, status_code=429)
# 1. Basic dependency - rate limit info available in endpoint
basic_rate_limit = RateLimitDependency(window_size=60, limit=10)

# Per-tier limiters consumed by ``tiered_rate_limit`` (section 2). They share
# the same window size; the distinct ``key_prefix`` values presumably keep
# each tier's counters apart (confirm against RateLimitDependency).
free_tier_limit = RateLimitDependency(
    window_size=60,
    limit=10,
    key_prefix="free",
)
pro_tier_limit = RateLimitDependency(
    window_size=60,
    limit=100,
    key_prefix="pro",
)
enterprise_tier_limit = RateLimitDependency(
    window_size=60,
    limit=1000,
    key_prefix="enterprise",
)

# Endpoints declare ``rate_info: BasicRateLimit`` to receive the
# RateLimitInfo produced by ``basic_rate_limit``.
BasicRateLimit: TypeAlias = Annotated[RateLimitInfo, Depends(basic_rate_limit)]
@app.get("/basic")
async def basic_endpoint(
    _: Request,
    rate_info: BasicRateLimit,
) -> dict[str, Any]:
    """Echo the limiter state that ``BasicRateLimit`` injected.

    Shows that endpoint logic can read ``limit``, ``remaining`` and
    ``reset_at`` directly from the dependency's result.
    """
    snapshot = dict(
        limit=rate_info.limit,
        remaining=rate_info.remaining,
        reset_at=rate_info.reset_at,
    )
    return {"message": "Success", "rate_limit": snapshot}
# 2. Different limits for different user tiers
def get_user_tier(request: Request) -> str:
    """Return the tier named in the ``X-User-Tier`` header, or ``"free"``.

    The header is supplied by the client, so this is for demonstration only;
    a real application would resolve the tier from a JWT or a database.
    """
    tier = request.headers.get("X-User-Tier", "free")
    return tier
TierDep: TypeAlias = Annotated[str, Depends(get_user_tier)]


async def tiered_rate_limit(
    request: Request,
    tier: TierDep,
) -> RateLimitInfo:
    """Apply the limiter that matches the caller's tier.

    ``"enterprise"`` and ``"pro"`` each have a dedicated limiter; every other
    value, including the ``"free"`` default, falls back to
    ``free_tier_limit``. The chosen dependency's result is returned as-is.
    """
    limiter_for_tier = {
        "enterprise": enterprise_tier_limit,
        "pro": pro_tier_limit,
    }
    chosen = limiter_for_tier.get(tier, free_tier_limit)
    return await chosen(request)


TieredRateLimit: TypeAlias = Annotated[RateLimitInfo, Depends(tiered_rate_limit)]
@app.get("/tiered")
async def tiered_endpoint(
    request: Request,
    rate_info: TieredRateLimit,
) -> dict[str, Any]:
    """Report the caller's tier alongside the limit that was applied to it."""
    limit_summary = {
        "limit": rate_info.limit,
        "remaining": rate_info.remaining,
    }
    return {
        "message": "Success",
        "tier": get_user_tier(request),
        "rate_limit": limit_summary,
    }
# 3. Conditional rate limiting based on request properties
def api_key_extractor(request: Request) -> str:
    """Build the rate-limit key from the ``X-API-Key`` header.

    The key is ``"api:"`` followed by the header value. Requests without the
    header all map to ``"api:anonymous"`` and therefore share one key.
    """
    key = request.headers.get("X-API-Key", "anonymous")
    return "api:" + key
# Per-API-key limiter: the key for each request comes from
# ``api_key_extractor`` rather than the default extractor.
api_rate_limit = RateLimitDependency(
    key_extractor=api_key_extractor,
    window_size=3600,
    limit=100,
)
ApiRateLimit: TypeAlias = Annotated[RateLimitInfo, Depends(api_rate_limit)]
@app.get("/api/resource")
async def api_resource(
    _: Request,
    rate_info: ApiRateLimit,
) -> dict[str, Any]:
    """Serve a resource and report how many calls the API key has left."""
    remaining = rate_info.remaining
    return {"data": "Resource data", "requests_remaining": remaining}
# 4. Combine multiple rate limits (e.g., per-minute AND per-hour)
per_minute_limit = RateLimitDependency(
    key_prefix="minute",
    limit=10,
    window_size=60,
)
per_hour_limit = RateLimitDependency(
    key_prefix="hour",
    limit=100,
    window_size=3600,
)

# Aliases that inject each limiter's RateLimitInfo into a dependant.
PerMinuteLimit: TypeAlias = Annotated[RateLimitInfo, Depends(per_minute_limit)]
PerHourLimit: TypeAlias = Annotated[RateLimitInfo, Depends(per_hour_limit)]
async def combined_rate_limit(
    _: Request,
    minute_info: PerMinuteLimit,
    hour_info: PerHourLimit,
) -> dict[str, Any]:
    """Depend on both the per-minute and the per-hour limiter.

    FastAPI resolves both ``Depends`` parameters before this runs. The result
    maps ``"minute"`` and ``"hour"`` to that window's ``limit`` and
    ``remaining`` values.
    """
    windows = (("minute", minute_info), ("hour", hour_info))
    return {
        label: {"limit": info.limit, "remaining": info.remaining}
        for label, info in windows
    }
CombinedRateLimit: TypeAlias = Annotated[dict[str, Any], Depends(combined_rate_limit)]


@app.get("/combined")
async def combined_endpoint(
    _: Request,
    rate_info: CombinedRateLimit,
) -> dict[str, Any]:
    """Return the per-minute / per-hour summary built by ``combined_rate_limit``."""
    response: dict[str, Any] = {"message": "Success"}
    response["rate_limits"] = rate_info
    return response
# 5. Rate limit with custom exemption logic
def is_internal_request(request: Request) -> bool:
    """Return True when the request carries the internal-service token.

    The ``X-Internal-Token`` header is checked against a hard-coded demo
    token. A missing header is never treated as internal.

    ``secrets.compare_digest`` is used so that comparison time does not
    reveal how much of a guessed token was correct. Both sides are encoded
    to bytes first, because ``compare_digest`` raises ``TypeError`` for
    ``str`` arguments containing non-ASCII characters. Client-supplied header
    values are not guaranteed to be ASCII.
    """
    internal_token = request.headers.get("X-Internal-Token")
    if internal_token is None:
        return False
    # NOTE: demo secret only; load it from configuration in a real service.
    return secrets.compare_digest(
        internal_token.encode("utf-8"),
        b"internal-secret-token",
    )
# Limiter whose ``exempt_when`` hook is ``is_internal_request``: requests
# that hook accepts are meant to bypass the limit.
internal_exempt_limit = RateLimitDependency(
    exempt_when=is_internal_request,
    window_size=60,
    limit=10,
)
InternalExemptLimit: TypeAlias = Annotated[
    RateLimitInfo,
    Depends(internal_exempt_limit),
]
@app.get("/internal-exempt")
async def internal_exempt_endpoint(
    request: Request,
    rate_info: InternalExemptLimit,
) -> dict[str, Any]:
    """Report whether the caller is internal (and therefore exempt).

    Quota information is included only for non-internal callers;
    ``rate_info.remaining`` is not read for internal requests.
    """
    is_internal = is_internal_request(request)
    if is_internal:
        quota = None
    else:
        quota = {"remaining": rate_info.remaining}
    return {
        "message": "Success",
        "is_internal": is_internal,
        "rate_limit": quota,
    }
if __name__ == "__main__":
    import argparse

    import uvicorn

    # Command-line overrides for the bind address; defaults come from the
    # module-level DEFAULT_HOST / DEFAULT_PORT constants.
    cli = argparse.ArgumentParser(description="Dependency injection example")
    cli.add_argument(
        "--host",
        help=f"Host to bind to (default: {DEFAULT_HOST})",
        default=DEFAULT_HOST,
    )
    cli.add_argument(
        "--port",
        help=f"Port to bind to (default: {DEFAULT_PORT})",
        default=DEFAULT_PORT,
        type=int,
    )
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)