Files
fastapi-traffic/examples/08_tiered_api.py
zanewalker f3453cb0fc release: bump version to 0.3.0
- Refactor Redis backend connection handling and pool management
- Update algorithm implementations with improved type annotations
- Enhance config loader validation with stricter Pydantic schemas
- Improve decorator and middleware error handling
- Expand example scripts with better docstrings and usage patterns
- Add new 00_basic_usage.py example for quick start
- Reorganize examples directory structure
- Fix type annotation inconsistencies across core modules
- Update dependencies in pyproject.toml
2026-03-17 21:04:34 +00:00

289 lines
7.9 KiB
Python

"""Example of a production-ready tiered API with different rate limits per plan."""
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Any
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimiter,
RateLimitExceeded,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
# Default bind address and port used by the `__main__` entry point below.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000

# Backend and limiter shared by every rate-limited endpoint in this module.
# The limiter is initialized and registered via `set_limiter` in `lifespan`.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Own the module-level rate limiter for the lifetime of the application.

    On startup the shared ``limiter`` is initialized and registered globally
    with ``set_limiter`` so rate-limit dependencies can find it. On shutdown
    the limiter is closed.

    The ``close`` call sits in a ``finally`` block: if an exception is thrown
    into this context manager at the ``yield``, the limiter is still released
    instead of the cleanup being silently skipped.
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        await limiter.close()
# The FastAPI application. `lifespan` initializes and closes the shared
# rate limiter around the app's run.
app = FastAPI(
    title="Tiered API Example",
    description="API with different rate limits based on subscription tier",
    lifespan=lifespan,
)
class Tier(str, Enum):
    """Subscription plans, from least to most generous.

    Members mix in ``str``, so ``.value`` is the plan name used in JSON
    responses. Declaration order is the order produced by iterating ``Tier``
    (relied on when listing upgrade options).
    """

    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    ENTERPRISE = "enterprise"
@dataclass
class TierConfig:
    """Rate limits and feature flags for a single subscription tier."""

    # Used as the `limit` of the tier's 60-second token-bucket dependency.
    requests_per_minute: int
    # Reported by /api/v1/tier-info; not used by any limiter in this file.
    requests_per_hour: int
    # Reported by /api/v1/tier-info and /pricing; not used by any limiter
    # in this file.
    requests_per_day: int
    # Passed as `burst_size` to the tier's token-bucket dependency.
    burst_size: int
    # Feature names checked by endpoints ("analytics" gates /api/v1/analytics).
    features: list[str]
# Limits and features for every tier. Dict order (free -> enterprise) is the
# order in which /pricing lists the tiers.
TIER_CONFIGS: dict[Tier, TierConfig] = {
    Tier.FREE: TierConfig(
        requests_per_minute=10,
        requests_per_hour=100,
        requests_per_day=500,
        burst_size=5,
        features=["basic_api"],
    ),
    Tier.STARTER: TierConfig(
        requests_per_minute=60,
        requests_per_hour=1000,
        requests_per_day=10000,
        burst_size=20,
        features=["basic_api", "webhooks"],
    ),
    Tier.PRO: TierConfig(
        requests_per_minute=300,
        requests_per_hour=10000,
        requests_per_day=100000,
        burst_size=50,
        # First tier with "analytics", hence "Pro tier or higher" in the 403.
        features=["basic_api", "webhooks", "analytics", "priority_support"],
    ),
    Tier.ENTERPRISE: TierConfig(
        requests_per_minute=1000,
        requests_per_hour=50000,
        requests_per_day=500000,
        burst_size=200,
        features=[
            "basic_api",
            "webhooks",
            "analytics",
            "priority_support",
            "sla",
            "custom_integrations",
        ],
    ),
}
# Simulated API key store: maps the raw `X-API-Key` header value to the
# caller's tier and user id. Hard-coded demo keys only; a real deployment
# would look these up in a database or secret store.
API_KEYS: dict[str, dict[str, Any]] = {
    "free-key-123": {"tier": Tier.FREE, "user_id": "user_1"},
    "starter-key-456": {"tier": Tier.STARTER, "user_id": "user_2"},
    "pro-key-789": {"tier": Tier.PRO, "user_id": "user_3"},
    "enterprise-key-000": {"tier": Tier.ENTERPRISE, "user_id": "user_4"},
}
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Render ``RateLimitExceeded`` as a tier-aware HTTP 429 JSON response.

    The caller's tier is looked up from the ``X-API-Key`` header; a missing or
    unknown key is reported as the free tier. Every tier except enterprise is
    given an ``upgrade_url``. When ``exc.limit_info`` is present, its
    ``to_headers()`` output is sent as the response headers.
    """
    known = API_KEYS.get(request.headers.get("X-API-Key", ""), {})
    caller_tier = known.get("tier", Tier.FREE)

    body: dict[str, Any] = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
        "tier": caller_tier.value,
        # Enterprise is the top plan, so there is nothing to upgrade to.
        "upgrade_url": (
            None if caller_tier == Tier.ENTERPRISE else "https://example.com/pricing"
        ),
    }
    response_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(status_code=429, content=body, headers=response_headers)
def get_api_key_info(request: Request) -> dict[str, Any]:
    """Authenticate the caller from the ``X-API-Key`` header.

    Returns a new dict containing ``"api_key"`` (the header value) followed by
    every field of the matching ``API_KEYS`` record.

    Raises:
        HTTPException: 401 when the header is absent or empty, or when the
            key has no (or an empty) record in ``API_KEYS``.
    """
    supplied_key = request.headers.get("X-API-Key")
    if not supplied_key:
        raise HTTPException(status_code=401, detail="API key required")

    record = API_KEYS.get(supplied_key)
    if record:
        return {"api_key": supplied_key, **record}
    raise HTTPException(status_code=401, detail="Invalid API key")
def get_tier_config(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> TierConfig:
    """Dependency resolving the ``TierConfig`` of the authenticated caller.

    Falls back to the free tier when the key record carries no ``"tier"``.
    """
    return TIER_CONFIGS[key_info.get("tier", Tier.FREE)]
# One token-bucket rate-limit dependency per tier: `requests_per_minute`
# tokens per 60-second window, with the tier's `burst_size`, and a per-tier
# key prefix so tiers never share buckets.
#
# Built with a comprehension so the loop variables do not leak into the
# module namespace as stray `tier` / `config` globals.
#
# NOTE(review): only the per-minute quota is enforced here;
# `requests_per_hour` and `requests_per_day` are advertised by the endpoints
# but no limiter in this file applies them.
# NOTE(review): `api_key_extractor` is defined in this module but not passed
# to these dependencies, so buckets use RateLimitDependency's default key.
# Confirm whether that default is per-client rather than per-API-key.
tier_rate_limits: dict[Tier, RateLimitDependency] = {
    plan: RateLimitDependency(
        limit=plan_config.requests_per_minute,
        window_size=60,
        algorithm=Algorithm.TOKEN_BUCKET,
        burst_size=plan_config.burst_size,
        key_prefix=f"tier_{plan.value}",
    )
    for plan, plan_config in TIER_CONFIGS.items()
}
def api_key_extractor(request: Request) -> str:
    """Build the rate-limit key ``"api:<key>"`` from the ``X-API-Key`` header.

    Requests without the header all share the key ``"api:anonymous"``.
    """
    return "api:" + request.headers.get("X-API-Key", "anonymous")
async def apply_tier_rate_limit(
    request: Request,
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Enforce the caller's tier rate limit and return the resolved context.

    The caller's tier (free if unset) selects a dependency from
    ``tier_rate_limits``, which is awaited with the current request.
    Presumably that call raises ``RateLimitExceeded`` when the caller is over
    quota (handled by ``rate_limit_handler``); confirm against the library.

    Returns:
        A dict with ``"tier"``, ``"config"`` (the tier's ``TierConfig``),
        ``"rate_info"`` (whatever the dependency returned) and ``"key_info"``.
    """
    caller_tier = key_info.get("tier", Tier.FREE)
    limiter_dep = tier_rate_limits[caller_tier]
    outcome = await limiter_dep(request)
    return dict(
        tier=caller_tier,
        config=TIER_CONFIGS[caller_tier],
        rate_info=outcome,
        key_info=key_info,
    )
@app.get("/api/v1/data")
async def get_data(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Return sample data together with the caller's tier and rate-limit state.

    ``rate_limit`` echoes the ``limit``, ``remaining`` and ``reset_at``
    attributes of the info returned by the tier's rate-limit dependency.
    """
    tier_name = limit_info["tier"].value
    rate = limit_info["rate_info"]
    return {
        "data": {"items": ["item1", "item2", "item3"]},
        "tier": tier_name,
        "rate_limit": {
            "limit": rate.limit,
            "remaining": rate.remaining,
            "reset_at": rate.reset_at,
        },
    }
@app.get("/api/v1/analytics")
async def get_analytics(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Return sample analytics to tiers whose features include ``"analytics"``.

    The tier rate limit is applied (via ``apply_tier_rate_limit``) before the
    feature check, so a request refused here presumably still counts against
    the caller's quota.

    Raises:
        HTTPException: 403 when the caller's tier lacks the analytics feature.
    """
    tier = limit_info["tier"]
    has_analytics = "analytics" in limit_info["config"].features
    if not has_analytics:
        raise HTTPException(
            status_code=403,
            detail=f"Analytics requires Pro tier or higher. Current tier: {tier.value}",
        )

    sample_stats = {
        "total_requests": 12345,
        "unique_users": 567,
        "avg_response_time_ms": 45,
    }
    return {"analytics": sample_stats, "tier": tier.value}
@app.get("/api/v1/tier-info")
async def get_tier_info(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Describe the caller's tier: its limits, features and possible upgrades.

    ``upgrade_options`` lists, in ``Tier`` declaration order, every tier whose
    per-minute limit is strictly higher than the caller's. This endpoint only
    authenticates; it applies no rate limit itself.
    """
    tier = key_info.get("tier", Tier.FREE)
    current = TIER_CONFIGS[tier]

    upgrades: list[str] = []
    for candidate in Tier:
        if TIER_CONFIGS[candidate].requests_per_minute > current.requests_per_minute:
            upgrades.append(candidate.value)

    return {
        "tier": tier.value,
        "limits": {
            "requests_per_minute": current.requests_per_minute,
            "requests_per_hour": current.requests_per_hour,
            "requests_per_day": current.requests_per_day,
            "burst_size": current.burst_size,
        },
        "features": current.features,
        "upgrade_options": upgrades,
    }
@app.get("/pricing")
async def pricing() -> dict[str, Any]:
    """Public (no API key, no rate limit) summary of every tier's plan.

    Tiers appear in ``TIER_CONFIGS`` order, each with its per-minute and
    per-day limits and its feature list.
    """
    catalogue: dict[str, Any] = {}
    for plan, plan_config in TIER_CONFIGS.items():
        catalogue[plan.value] = {
            "requests_per_minute": plan_config.requests_per_minute,
            "requests_per_day": plan_config.requests_per_day,
            "features": plan_config.features,
        }
    return {"tiers": catalogue}
if __name__ == "__main__":
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser(description="Tiered API example")
    option_specs = (
        (
            "--host",
            {
                "default": DEFAULT_HOST,
                "help": f"Host to bind to (default: {DEFAULT_HOST})",
            },
        ),
        (
            "--port",
            {
                "type": int,
                "default": DEFAULT_PORT,
                "help": f"Port to bind to (default: {DEFAULT_PORT})",
            },
        ),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()

    # Try it with different API keys (examples assume the default port 8000;
    # adjust the URL if you pass --host/--port):
    #   curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
    #   curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
    #   curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/analytics
    #     -> 403, the free tier has no "analytics" feature
    uvicorn.run(app, host=options.host, port=options.port)