Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
256
examples/08_tiered_api.py
Normal file
256
examples/08_tiered_api.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Example of a production-ready tiered API with different rate limits per plan."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await limiter.initialize()
|
||||
set_limiter(limiter)
|
||||
yield
|
||||
await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Tiered API Example",
|
||||
description="API with different rate limits based on subscription tier",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
class Tier(str, Enum):
|
||||
FREE = "free"
|
||||
STARTER = "starter"
|
||||
PRO = "pro"
|
||||
ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TierConfig:
|
||||
requests_per_minute: int
|
||||
requests_per_hour: int
|
||||
requests_per_day: int
|
||||
burst_size: int
|
||||
features: list[str]
|
||||
|
||||
|
||||
# Tier configurations
|
||||
TIER_CONFIGS: dict[Tier, TierConfig] = {
|
||||
Tier.FREE: TierConfig(
|
||||
requests_per_minute=10,
|
||||
requests_per_hour=100,
|
||||
requests_per_day=500,
|
||||
burst_size=5,
|
||||
features=["basic_api"],
|
||||
),
|
||||
Tier.STARTER: TierConfig(
|
||||
requests_per_minute=60,
|
||||
requests_per_hour=1000,
|
||||
requests_per_day=10000,
|
||||
burst_size=20,
|
||||
features=["basic_api", "webhooks"],
|
||||
),
|
||||
Tier.PRO: TierConfig(
|
||||
requests_per_minute=300,
|
||||
requests_per_hour=10000,
|
||||
requests_per_day=100000,
|
||||
burst_size=50,
|
||||
features=["basic_api", "webhooks", "analytics", "priority_support"],
|
||||
),
|
||||
Tier.ENTERPRISE: TierConfig(
|
||||
requests_per_minute=1000,
|
||||
requests_per_hour=50000,
|
||||
requests_per_day=500000,
|
||||
burst_size=200,
|
||||
features=["basic_api", "webhooks", "analytics", "priority_support", "sla", "custom_integrations"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Simulated API key database
|
||||
API_KEYS: dict[str, dict[str, Any]] = {
|
||||
"free-key-123": {"tier": Tier.FREE, "user_id": "user_1"},
|
||||
"starter-key-456": {"tier": Tier.STARTER, "user_id": "user_2"},
|
||||
"pro-key-789": {"tier": Tier.PRO, "user_id": "user_3"},
|
||||
"enterprise-key-000": {"tier": Tier.ENTERPRISE, "user_id": "user_4"},
|
||||
}
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
|
||||
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
api_key = request.headers.get("X-API-Key", "")
|
||||
key_info = API_KEYS.get(api_key, {})
|
||||
tier = key_info.get("tier", Tier.FREE)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": exc.message,
|
||||
"retry_after": exc.retry_after,
|
||||
"tier": tier.value,
|
||||
"upgrade_url": "https://example.com/pricing" if tier != Tier.ENTERPRISE else None,
|
||||
},
|
||||
headers=exc.limit_info.to_headers() if exc.limit_info else {},
|
||||
)
|
||||
|
||||
|
||||
def get_api_key_info(request: Request) -> dict[str, Any]:
|
||||
"""Validate API key and return info."""
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="API key required")
|
||||
|
||||
key_info = API_KEYS.get(api_key)
|
||||
if not key_info:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return {"api_key": api_key, **key_info}
|
||||
|
||||
|
||||
def get_tier_config(key_info: dict[str, Any] = Depends(get_api_key_info)) -> TierConfig:
|
||||
"""Get rate limit config for user's tier."""
|
||||
tier = key_info.get("tier", Tier.FREE)
|
||||
return TIER_CONFIGS[tier]
|
||||
|
||||
|
||||
# Create rate limit dependencies for each tier
|
||||
tier_rate_limits: dict[Tier, RateLimitDependency] = {}
|
||||
for tier, config in TIER_CONFIGS.items():
|
||||
tier_rate_limits[tier] = RateLimitDependency(
|
||||
limit=config.requests_per_minute,
|
||||
window_size=60,
|
||||
algorithm=Algorithm.TOKEN_BUCKET,
|
||||
burst_size=config.burst_size,
|
||||
key_prefix=f"tier_{tier.value}",
|
||||
)
|
||||
|
||||
|
||||
def api_key_extractor(request: Request) -> str:
|
||||
"""Extract API key for rate limiting."""
|
||||
api_key = request.headers.get("X-API-Key", "anonymous")
|
||||
return f"api:{api_key}"
|
||||
|
||||
|
||||
async def apply_tier_rate_limit(
|
||||
request: Request,
|
||||
key_info: dict[str, Any] = Depends(get_api_key_info),
|
||||
) -> dict[str, Any]:
|
||||
"""Apply rate limit based on user's tier."""
|
||||
tier = key_info.get("tier", Tier.FREE)
|
||||
rate_limit_dep = tier_rate_limits[tier]
|
||||
rate_info = await rate_limit_dep(request)
|
||||
|
||||
return {
|
||||
"tier": tier,
|
||||
"config": TIER_CONFIGS[tier],
|
||||
"rate_info": rate_info,
|
||||
"key_info": key_info,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/data")
|
||||
async def get_data(
|
||||
request: Request,
|
||||
limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Get data with tier-based rate limiting."""
|
||||
return {
|
||||
"data": {"items": ["item1", "item2", "item3"]},
|
||||
"tier": limit_info["tier"].value,
|
||||
"rate_limit": {
|
||||
"limit": limit_info["rate_info"].limit,
|
||||
"remaining": limit_info["rate_info"].remaining,
|
||||
"reset_at": limit_info["rate_info"].reset_at,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/analytics")
|
||||
async def get_analytics(
|
||||
request: Request,
|
||||
limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""Analytics endpoint - requires Pro tier or higher."""
|
||||
tier = limit_info["tier"]
|
||||
config = limit_info["config"]
|
||||
|
||||
if "analytics" not in config.features:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Analytics requires Pro tier or higher. Current tier: {tier.value}",
|
||||
)
|
||||
|
||||
return {
|
||||
"analytics": {
|
||||
"total_requests": 12345,
|
||||
"unique_users": 567,
|
||||
"avg_response_time_ms": 45,
|
||||
},
|
||||
"tier": tier.value,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/tier-info")
|
||||
async def get_tier_info(
|
||||
key_info: dict[str, Any] = Depends(get_api_key_info),
|
||||
) -> dict[str, Any]:
|
||||
"""Get information about current tier and limits."""
|
||||
tier = key_info.get("tier", Tier.FREE)
|
||||
config = TIER_CONFIGS[tier]
|
||||
|
||||
return {
|
||||
"tier": tier.value,
|
||||
"limits": {
|
||||
"requests_per_minute": config.requests_per_minute,
|
||||
"requests_per_hour": config.requests_per_hour,
|
||||
"requests_per_day": config.requests_per_day,
|
||||
"burst_size": config.burst_size,
|
||||
},
|
||||
"features": config.features,
|
||||
"upgrade_options": [t.value for t in Tier if TIER_CONFIGS[t].requests_per_minute > config.requests_per_minute],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/pricing")
|
||||
async def pricing() -> dict[str, Any]:
|
||||
"""Public pricing information."""
|
||||
return {
|
||||
"tiers": {
|
||||
tier.value: {
|
||||
"requests_per_minute": config.requests_per_minute,
|
||||
"requests_per_day": config.requests_per_day,
|
||||
"features": config.features,
|
||||
}
|
||||
for tier, config in TIER_CONFIGS.items()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Test with different API keys:
|
||||
# curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
|
||||
# curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user