Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.) - SQLite and memory backends - Decorator and dependency injection patterns - Middleware support - Example usage files
This commit is contained in:
60
examples/01_quickstart.py
Normal file
60
examples/01_quickstart.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Quickstart example - minimal setup to get rate limiting working."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
# Step 1: Create a backend and limiter
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the rate limiter before serving and tear it down afterwards."""
    # Startup: prepare the backend and publish the limiter globally.
    await limiter.initialize()
    set_limiter(limiter)
    yield
    # Shutdown: release backend resources.
    await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Quickstart Example", lifespan=lifespan)
|
||||
|
||||
|
||||
# Step 2: Add exception handler for rate limit errors
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate RateLimitExceeded into an HTTP 429 JSON response."""
    payload = {"error": "Too many requests", "retry_after": exc.retry_after}
    return JSONResponse(status_code=429, content=payload)
|
||||
|
||||
|
||||
# Step 3: Apply rate limiting to endpoints
|
||||
@app.get("/")
@rate_limit(10, 60)  # 10 requests per minute
async def hello(request: Request) -> dict[str, str]:
    """Root endpoint, limited to 10 requests per minute per client."""
    return {"message": "Hello, World!"}
|
||||
|
||||
|
||||
@app.get("/api/data")
@rate_limit(100, 60)  # 100 requests per minute
async def get_data(request: Request) -> dict[str, str]:
    """Data endpoint with a more generous per-minute budget."""
    return {"data": "Some important data"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): 0.0.0.0 binds every interface — fine for a demo,
    # audit before using in production.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
122
examples/02_algorithms.py
Normal file
122
examples/02_algorithms.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Examples demonstrating all available rate limiting algorithms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Wire the limiter into the app lifecycle (init on start, close on stop)."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Rate Limiting Algorithms", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Return a 429 carrying limit details plus standard rate-limit headers."""
    body = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    # Only attach X-RateLimit-* style headers when limit info is available.
    extra_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(status_code=429, content=body, headers=extra_headers)
|
||||
|
||||
|
||||
# 1. Fixed Window - Simple, resets at fixed intervals
|
||||
# Best for: Simple use cases, low memory usage
|
||||
# Drawback: Can allow 2x burst at window boundaries
|
||||
@app.get("/fixed-window")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
async def fixed_window(request: Request) -> dict[str, str]:
    """Fixed window resets counter at fixed time intervals."""
    return {"algorithm": "fixed_window", "description": "Counter resets every 60 seconds"}
|
||||
|
||||
|
||||
# 2. Sliding Window Log - Most precise
|
||||
# Best for: When accuracy is critical
|
||||
# Drawback: Higher memory usage (stores all timestamps)
|
||||
@app.get("/sliding-window")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
async def sliding_window(request: Request) -> dict[str, str]:
    """Sliding window tracks exact timestamps for precise limiting."""
    return {"algorithm": "sliding_window", "description": "Precise tracking with timestamp log"}
|
||||
|
||||
|
||||
# 3. Sliding Window Counter - Balance of precision and efficiency
|
||||
# Best for: Most production use cases (default algorithm)
|
||||
# Combines benefits of fixed window efficiency with sliding window precision
|
||||
@app.get("/sliding-window-counter")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.SLIDING_WINDOW_COUNTER)
async def sliding_window_counter(request: Request) -> dict[str, str]:
    """Sliding window counter uses weighted counts from current and previous windows."""
    return {"algorithm": "sliding_window_counter", "description": "Efficient approximation"}
|
||||
|
||||
|
||||
# 4. Token Bucket - Allows controlled bursts
|
||||
# Best for: APIs that need to allow occasional bursts
|
||||
# Tokens refill gradually, burst_size controls max burst
|
||||
@app.get("/token-bucket")
@rate_limit(
    limit=10,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=5,  # Allow bursts of up to 5 requests
)
async def token_bucket(request: Request) -> dict[str, str]:
    """Token bucket allows bursts up to burst_size, then refills gradually."""
    return {"algorithm": "token_bucket", "description": "Allows controlled bursts"}
|
||||
|
||||
|
||||
# 5. Leaky Bucket - Smooths out traffic
|
||||
# Best for: Protecting downstream services from bursts
|
||||
# Processes requests at a constant rate
|
||||
@app.get("/leaky-bucket")
@rate_limit(
    limit=10,
    window_size=60,
    algorithm=Algorithm.LEAKY_BUCKET,
    burst_size=5,  # Queue capacity
)
async def leaky_bucket(request: Request) -> dict[str, str]:
    """Leaky bucket smooths traffic to a constant rate."""
    return {"algorithm": "leaky_bucket", "description": "Constant output rate"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve the algorithm demo app locally.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
108
examples/03_backends.py
Normal file
108
examples/03_backends.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Examples demonstrating different storage backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
SQLiteBackend,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
|
||||
# Choose backend based on environment
|
||||
def get_backend():
    """Select the storage backend from the RATE_LIMIT_BACKEND env var.

    Supported values: "memory" (default), "sqlite", "redis".
    Falls back to MemoryBackend when redis is requested but unavailable.

    Returns:
        A backend instance (SQLiteBackend, RedisBackend, or MemoryBackend).
    """
    backend_type = os.getenv("RATE_LIMIT_BACKEND", "memory")

    if backend_type == "sqlite":
        # SQLite - Good for single-instance apps, persists across restarts
        return SQLiteBackend("rate_limits.db")

    elif backend_type == "redis":
        # Redis - Required for distributed/multi-instance deployments
        # Requires: pip install redis
        try:
            from fastapi_traffic import RedisBackend
            import asyncio

            async def create_redis():
                return await RedisBackend.from_url(
                    os.getenv("REDIS_URL", "redis://localhost:6379/0"),
                    key_prefix="myapp_ratelimit",
                )

            # BUG FIX: asyncio.get_event_loop().run_until_complete() is
            # deprecated and raises RuntimeError on modern Python when no
            # event loop exists in the current thread. asyncio.run()
            # creates and tears down a fresh loop for this one-off call.
            return asyncio.run(create_redis())
        except ImportError:
            print("Redis not installed, falling back to memory backend")
            return MemoryBackend()

    else:
        # Memory - Fast, but resets on restart, not shared across instances
        return MemoryBackend()
|
||||
|
||||
|
||||
backend = get_backend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the configured backend's limiter on startup; close on exit."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Storage Backends Example", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Map rate-limit violations to a 429 with a retry hint."""
    content = {"error": "rate_limit_exceeded", "retry_after": exc.retry_after}
    return JSONResponse(status_code=429, content=content)
|
||||
|
||||
|
||||
@app.get("/api/resource")
@rate_limit(100, 60)
async def get_resource(request: Request) -> dict[str, str]:
    """Rate-limited resource endpoint; reports which backend is active."""
    return {"message": "Resource data", "backend": type(backend).__name__}
|
||||
|
||||
|
||||
@app.get("/backend-info")
async def backend_info() -> dict[str, Any]:
    """Get information about the current backend."""
    # Describe the active backend; anything that is neither memory nor
    # sqlite is assumed to be the Redis backend.
    if isinstance(backend, MemoryBackend):
        description = "In-memory storage, fast but ephemeral"
    elif isinstance(backend, SQLiteBackend):
        description = "SQLite storage, persistent, single-instance"
    else:
        description = "Redis storage, distributed, multi-instance"

    return {
        "backend_type": type(backend).__name__,
        "description": description,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Run with different backends:
    # RATE_LIMIT_BACKEND=memory python 03_backends.py
    # RATE_LIMIT_BACKEND=sqlite python 03_backends.py
    # RATE_LIMIT_BACKEND=redis REDIS_URL=redis://localhost:6379/0 python 03_backends.py
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
153
examples/04_key_extractors.py
Normal file
153
examples/04_key_extractors.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Examples demonstrating custom key extractors for rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Bring the limiter up before serving requests and close it afterwards."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Custom Key Extractors", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Respond with 429 and the retry-after hint when a limit is exceeded."""
    return JSONResponse(
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
        status_code=429,
    )
|
||||
|
||||
|
||||
# 1. Default: Rate limit by IP address
|
||||
@app.get("/by-ip")
@rate_limit(10, 60)  # Uses default IP-based key extractor
async def by_ip(request: Request) -> dict[str, str]:
    """Rate limited by client IP address (default behavior)."""
    client_ip = request.client.host if request.client else "unknown"
    return {"limited_by": "ip", "client_ip": client_ip}
|
||||
|
||||
|
||||
# 2. Rate limit by API key
|
||||
def api_key_extractor(request: Request) -> str:
    """Build a rate-limit key from the X-API-Key header (or "anonymous")."""
    key = request.headers.get("X-API-Key", "anonymous")
    return f"api_key:{key}"
|
||||
|
||||
|
||||
@app.get("/by-api-key")
@rate_limit(
    limit=100,
    window_size=3600,  # 100 requests per hour per API key
    key_extractor=api_key_extractor,
)
async def by_api_key(request: Request) -> dict[str, str]:
    """Rate limited by API key from X-API-Key header."""
    key = request.headers.get("X-API-Key", "anonymous")
    return {"limited_by": "api_key", "api_key": key}
|
||||
|
||||
|
||||
# 3. Rate limit by user ID (from JWT or session)
|
||||
def user_id_extractor(request: Request) -> str:
    """Build a per-user rate-limit key from the X-User-ID header."""
    # In real apps, this would come from decoded JWT or session
    uid = request.headers.get("X-User-ID", "anonymous")
    return f"user:{uid}"
|
||||
|
||||
|
||||
@app.get("/by-user")
@rate_limit(limit=50, window_size=60, key_extractor=user_id_extractor)
async def by_user(request: Request) -> dict[str, str]:
    """Rate limited by user ID."""
    uid = request.headers.get("X-User-ID", "anonymous")
    return {"limited_by": "user_id", "user_id": uid}
|
||||
|
||||
|
||||
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
|
||||
def endpoint_ip_extractor(request: Request) -> str:
    """Key on (endpoint path, client IP) so each route gets its own counter."""
    client_ip = request.client.host if request.client else "unknown"
    return f"endpoint:{request.url.path}:ip:{client_ip}"
|
||||
|
||||
|
||||
@app.get("/endpoint-specific")
@rate_limit(limit=5, window_size=60, key_extractor=endpoint_ip_extractor)
async def endpoint_specific(request: Request) -> dict[str, str]:
    """Each endpoint has its own rate limit counter."""
    return {"limited_by": "endpoint+ip"}
|
||||
|
||||
|
||||
# 5. Rate limit by organization/tenant (multi-tenant apps)
|
||||
def tenant_extractor(request: Request) -> str:
    """Key on the tenant from the X-Tenant-ID header for multi-tenant limits."""
    # Could also parse from subdomain: tenant.example.com
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return f"tenant:{tenant_id}"
|
||||
|
||||
|
||||
@app.get("/by-tenant")
@rate_limit(
    limit=1000,
    window_size=3600,  # 1000 requests per hour per tenant
    key_extractor=tenant_extractor,
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Rate limited by tenant/organization."""
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return {"limited_by": "tenant", "tenant_id": tenant_id}
|
||||
|
||||
|
||||
# 6. Composite key: User + Action type
|
||||
def user_action_extractor(request: Request) -> str:
    """Composite key: limit a specific action per user."""
    uid = request.headers.get("X-User-ID", "anonymous")
    act = request.query_params.get("action", "default")
    return f"user:{uid}:action:{act}"
|
||||
|
||||
|
||||
@app.get("/user-action")
@rate_limit(limit=10, window_size=60, key_extractor=user_action_extractor)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Rate limited by user + action combination."""
    uid = request.headers.get("X-User-ID", "anonymous")
    return {"limited_by": "user+action", "user_id": uid, "action": action}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve the key-extractor demo app locally.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
109
examples/05_middleware.py
Normal file
109
examples/05_middleware.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Examples demonstrating middleware-based rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from fastapi_traffic import MemoryBackend
|
||||
from fastapi_traffic.middleware import RateLimitMiddleware
|
||||
|
||||
# Alternative middleware options (uncomment to use):
|
||||
# from fastapi_traffic.middleware import SlidingWindowMiddleware
|
||||
# from fastapi_traffic.middleware import TokenBucketMiddleware
|
||||
|
||||
app = FastAPI(title="Middleware Rate Limiting")
|
||||
|
||||
|
||||
# Custom key extractor for middleware
|
||||
def get_client_identifier(request: Request) -> str:
    """Identify a client: API key first, then IP, then "unknown"."""
    key = request.headers.get("X-API-Key")
    if key:
        return f"api_key:{key}"
    if request.client:
        return f"ip:{request.client.host}"
    return "unknown"
|
||||
|
||||
|
||||
# Option 1: Basic middleware with defaults
|
||||
# Uncomment to use:
|
||||
# app.add_middleware(
|
||||
# RateLimitMiddleware,
|
||||
# limit=100,
|
||||
# window_size=60,
|
||||
# )
|
||||
|
||||
# Option 2: Middleware with custom configuration
|
||||
# Register global rate limiting for every request. Each keyword below is
# forwarded to RateLimitMiddleware; exempt paths/IPs bypass limiting
# entirely, and skip_on_error keeps the app serving if the backend fails.
app.add_middleware(
    RateLimitMiddleware,
    limit=100,  # max requests per window per client key
    window_size=60,  # window length in seconds
    backend=MemoryBackend(),  # per-process counters; not shared across instances
    key_prefix="global",
    include_headers=True,  # emit rate-limit headers on responses
    error_message="You have exceeded the rate limit. Please slow down.",
    status_code=429,
    skip_on_error=True,  # Don't block requests if backend fails
    exempt_paths={"/health", "/docs", "/openapi.json", "/redoc"},
    exempt_ips={"127.0.0.1"},  # Exempt localhost
    key_extractor=get_client_identifier,  # API key preferred, IP fallback
)
|
||||
|
||||
|
||||
# Option 3: Convenience middleware for specific algorithms
|
||||
# SlidingWindowMiddleware - precise rate limiting
|
||||
# app.add_middleware(
|
||||
# SlidingWindowMiddleware,
|
||||
# limit=100,
|
||||
# window_size=60,
|
||||
# )
|
||||
|
||||
# TokenBucketMiddleware - allows bursts
|
||||
# app.add_middleware(
|
||||
# TokenBucketMiddleware,
|
||||
# limit=100,
|
||||
# window_size=60,
|
||||
# )
|
||||
|
||||
|
||||
@app.get("/")
async def root() -> dict[str, str]:
    """Root endpoint - rate limited by middleware."""
    return {"message": "Hello, World!"}
|
||||
|
||||
|
||||
@app.get("/api/data")
async def get_data() -> dict[str, str]:
    """API endpoint - rate limited by middleware."""
    return {"data": "Some important data"}
|
||||
|
||||
|
||||
@app.get("/api/users")
async def get_users() -> dict[str, list[str]]:
    """Users endpoint - rate limited by middleware."""
    return {"users": ["alice", "bob", "charlie"]}
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> dict[str, str]:
    """Health check - exempt from rate limiting."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/docs-info")
async def docs_info() -> dict[str, str]:
    """Point callers at the auto-generated API docs (also exempt routes)."""
    return {
        "message": "Visit /docs for Swagger UI or /redoc for ReDoc",
        "note": "These endpoints are exempt from rate limiting",
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve the middleware demo app locally.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
221
examples/06_dependency_injection.py
Normal file
221
examples/06_dependency_injection.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Examples demonstrating rate limiting with FastAPI dependency injection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
|
||||
limiter = RateLimiter(backend)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: initialize the limiter on startup, close on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Dependency Injection Example", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Serve a 429 with a retry hint on rate-limit violations."""
    body = {"error": "rate_limit_exceeded", "retry_after": exc.retry_after}
    return JSONResponse(status_code=429, content=body)
|
||||
|
||||
|
||||
# 1. Basic dependency - rate limit info available in endpoint
|
||||
basic_rate_limit = RateLimitDependency(limit=10, window_size=60)
|
||||
|
||||
|
||||
@app.get("/basic")
async def basic_endpoint(
    request: Request,
    rate_info: Any = Depends(basic_rate_limit),
) -> dict[str, Any]:
    """Access rate limit info in your endpoint logic."""
    # The dependency injects the current window's limit state.
    limit_state = {
        "limit": rate_info.limit,
        "remaining": rate_info.remaining,
        "reset_at": rate_info.reset_at,
    }
    return {"message": "Success", "rate_limit": limit_state}
|
||||
|
||||
|
||||
# 2. Different limits for different user tiers
|
||||
def get_user_tier(request: Request) -> str:
    """Read the caller's tier from a header (real apps: JWT/database)."""
    return request.headers.get("X-User-Tier", "free")
|
||||
|
||||
|
||||
free_tier_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
key_prefix="free",
|
||||
)
|
||||
|
||||
pro_tier_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=60,
|
||||
key_prefix="pro",
|
||||
)
|
||||
|
||||
enterprise_tier_limit = RateLimitDependency(
|
||||
limit=1000,
|
||||
window_size=60,
|
||||
key_prefix="enterprise",
|
||||
)
|
||||
|
||||
|
||||
async def tiered_rate_limit(
    request: Request,
    tier: str = Depends(get_user_tier),
) -> Any:
    """Apply different rate limits based on user tier."""
    # Dispatch table replaces the if/elif chain; unknown tiers map to free.
    dependency = {
        "enterprise": enterprise_tier_limit,
        "pro": pro_tier_limit,
    }.get(tier, free_tier_limit)
    return await dependency(request)
|
||||
|
||||
|
||||
@app.get("/tiered")
async def tiered_endpoint(
    request: Request,
    rate_info: Any = Depends(tiered_rate_limit),
) -> dict[str, Any]:
    """Endpoint with tier-based rate limiting."""
    return {
        "message": "Success",
        "tier": get_user_tier(request),
        "rate_limit": {
            "limit": rate_info.limit,
            "remaining": rate_info.remaining,
        },
    }
|
||||
|
||||
|
||||
# 3. Conditional rate limiting based on request properties
|
||||
def api_key_extractor(request: Request) -> str:
    """Derive the rate-limit key from the X-API-Key header."""
    key = request.headers.get("X-API-Key", "anonymous")
    return f"api:{key}"
|
||||
|
||||
|
||||
api_rate_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=3600,
|
||||
key_extractor=api_key_extractor,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/resource")
|
||||
async def api_resource(
|
||||
request: Request,
|
||||
rate_info: Any = Depends(api_rate_limit),
|
||||
) -> dict[str, Any]:
|
||||
"""API endpoint with per-API-key rate limiting."""
|
||||
return {
|
||||
"data": "Resource data",
|
||||
"requests_remaining": rate_info.remaining,
|
||||
}
|
||||
|
||||
|
||||
# 4. Combine multiple rate limits (e.g., per-minute AND per-hour)
|
||||
per_minute_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
key_prefix="minute",
|
||||
)
|
||||
|
||||
per_hour_limit = RateLimitDependency(
|
||||
limit=100,
|
||||
window_size=3600,
|
||||
key_prefix="hour",
|
||||
)
|
||||
|
||||
|
||||
async def combined_rate_limit(
    request: Request,
    minute_info: Any = Depends(per_minute_limit),
    hour_info: Any = Depends(per_hour_limit),
) -> dict[str, Any]:
    """Apply both per-minute and per-hour limits."""

    def _summary(info: Any) -> dict[str, Any]:
        # Condense a limit-state object to the fields callers care about.
        return {"limit": info.limit, "remaining": info.remaining}

    return {"minute": _summary(minute_info), "hour": _summary(hour_info)}
|
||||
|
||||
|
||||
@app.get("/combined")
async def combined_endpoint(
    request: Request,
    rate_info: dict[str, Any] = Depends(combined_rate_limit),
) -> dict[str, Any]:
    """Endpoint with multiple rate limit tiers."""
    return {"message": "Success", "rate_limits": rate_info}
|
||||
|
||||
|
||||
# 5. Rate limit with custom exemption logic
|
||||
def is_internal_request(request: Request) -> bool:
    """True when the request carries the internal-service token."""
    # NOTE(review): hard-coded token is demo-only; use a secret store in prod.
    return request.headers.get("X-Internal-Token") == "internal-secret-token"
|
||||
|
||||
|
||||
internal_exempt_limit = RateLimitDependency(
|
||||
limit=10,
|
||||
window_size=60,
|
||||
exempt_when=is_internal_request,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/internal-exempt")
async def internal_exempt_endpoint(
    request: Request,
    rate_info: Any = Depends(internal_exempt_limit),
) -> dict[str, Any]:
    """Internal requests are exempt from rate limiting."""
    internal = is_internal_request(request)
    # Exempt requests have no meaningful limit state to report.
    limit_state = None if internal else {"remaining": rate_info.remaining}
    return {
        "message": "Success",
        "is_internal": internal,
        "rate_limit": limit_state,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve the dependency-injection demo app locally.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
197
examples/07_redis_distributed.py
Normal file
197
examples/07_redis_distributed.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Example demonstrating Redis backend for distributed rate limiting.
|
||||
|
||||
This example shows how to use Redis for rate limiting across multiple
|
||||
application instances (e.g., in a Kubernetes deployment or load-balanced setup).
|
||||
|
||||
Requirements:
|
||||
pip install redis
|
||||
|
||||
Environment variables:
|
||||
REDIS_URL: Redis connection URL (default: redis://localhost:6379/0)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
from fastapi_traffic.backends.redis import RedisBackend
|
||||
|
||||
|
||||
async def create_redis_backend():
    """Create Redis backend with fallback to memory.

    Reads REDIS_URL (default redis://localhost:6379/0), verifies the
    connection with a ping, and degrades to MemoryBackend on any failure.
    """
    try:
        # NOTE(review): RedisBackend is also imported unconditionally at
        # module level (from fastapi_traffic.backends.redis); if that import
        # failed the module would never load, so this ImportError branch
        # looks unreachable as written — confirm which import is intended.
        from fastapi_traffic import RedisBackend

        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        backend = await RedisBackend.from_url(
            redis_url,
            key_prefix="myapp",
        )

        # Verify connection
        if await backend.ping():
            print(f"Connected to Redis at {redis_url}")
            return backend
        else:
            print("Redis ping failed, falling back to memory backend")
            return MemoryBackend()

    except ImportError:
        print("Redis package not installed. Install with: pip install redis")
        print("Falling back to memory backend")
        return MemoryBackend()

    except Exception as e:
        # Broad catch is deliberate: any connection error should degrade
        # to the in-memory backend rather than crash application startup.
        print(f"Failed to connect to Redis: {e}")
        print("Falling back to memory backend")
        return MemoryBackend()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the backend/limiter at startup, store them on app.state."""
    # Startup: pick Redis (or memory fallback) and publish the limiter.
    app.state.backend = await create_redis_backend()
    app.state.limiter = RateLimiter(app.state.backend)
    await app.state.limiter.initialize()
    set_limiter(app.state.limiter)

    yield

    # Shutdown: release the limiter and its backend connection.
    await app.state.limiter.close()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Distributed Rate Limiting with Redis",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
def get_backend(request: Request) -> RedisBackend | MemoryBackend:
    """FastAPI dependency returning the backend stored on app.state."""
    return request.app.state.backend
|
||||
|
||||
|
||||
def get_limiter(request: Request) -> RateLimiter:
    """FastAPI dependency returning the limiter stored on app.state."""
    return request.app.state.limiter
|
||||
|
||||
|
||||
BackendDep = Annotated[RedisBackend | MemoryBackend, Depends(get_backend)]
|
||||
LimiterDep = Annotated[RateLimiter, Depends(get_limiter)]
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """429 response with message, retry hint, and rate-limit headers."""
    extra_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=extra_headers,
    )
|
||||
|
||||
|
||||
# Rate limits are shared across all instances when using Redis
|
||||
@app.get("/api/shared-limit")
@rate_limit(limit=100, window_size=60, key_prefix="shared")
async def shared_limit(request: Request) -> dict[str, str]:
    """This rate limit is shared across all application instances."""
    return {
        "message": "Success",
        "note": "Rate limit counter is shared via Redis",
    }
|
||||
|
||||
|
||||
# Per-user limits also work across instances
|
||||
def user_extractor(request: Request) -> str:
    """Per-user rate-limit key from the X-User-ID header."""
    return "user:" + request.headers.get("X-User-ID", "anonymous")
|
||||
|
||||
|
||||
@app.get("/api/user-limit")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_extractor,
    key_prefix="user_api",
)
async def user_limit(request: Request) -> dict[str, str]:
    """Per-user rate limit shared across instances."""
    uid = request.headers.get("X-User-ID", "anonymous")
    return {"message": "Success", "user_id": uid}
|
||||
|
||||
|
||||
# Token bucket works well with Redis for burst handling
|
||||
@app.get("/api/burst-allowed")
@rate_limit(
    limit=100,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=20,
    key_prefix="burst",
)
async def burst_allowed(request: Request) -> dict[str, str]:
    """Token bucket with Redis allows controlled bursts across instances."""
    return {"message": "Burst request successful"}
|
||||
|
||||
|
||||
@app.get("/health")
async def health(backend: BackendDep) -> dict[str, object]:
    """Health check with Redis status."""
    # Only the Redis backend exposes ping(); other backends report False.
    connected = False
    if hasattr(backend, "ping"):
        try:
            connected = await backend.ping()
        except Exception:
            connected = False
    return {
        "status": "healthy",
        "backend": type(backend).__name__,
        "redis_connected": connected,
    }
|
||||
|
||||
|
||||
@app.get("/stats")
async def stats(backend: BackendDep) -> dict[str, object]:
    """Get rate limiting statistics from Redis."""
    # Backends without get_stats() (e.g. memory) return a placeholder.
    if not hasattr(backend, "get_stats"):
        return {"message": "Stats not available for this backend"}
    try:
        return await backend.get_stats()
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Run multiple instances on different ports to test distributed limiting:
    # REDIS_URL=redis://localhost:6379/0 python 07_redis_distributed.py
    # In another terminal:
    # uvicorn 07_redis_distributed:app --port 8001
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
256
examples/08_tiered_api.py
Normal file
256
examples/08_tiered_api.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Example of a production-ready tiered API with different rate limits per plan."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
# Module-level backend/limiter so the decorators below can reference them.
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter on startup and release it on shutdown."""
    await limiter.initialize()
    # Register the limiter globally so @rate_limit / dependencies find it.
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(
    title="Tiered API Example",
    description="API with different rate limits based on subscription tier",
    lifespan=lifespan,
)
|
||||
|
||||
|
||||
class Tier(str, Enum):
    """Subscription tiers, ordered from most to least restricted."""

    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    ENTERPRISE = "enterprise"


@dataclass
class TierConfig:
    """Rate-limit quotas and feature flags for one subscription tier."""

    # Quotas at three window sizes; only the per-minute quota is enforced
    # by the dependencies built below — the others are informational.
    requests_per_minute: int
    requests_per_hour: int
    requests_per_day: int
    # Extra capacity for the token-bucket algorithm.
    burst_size: int
    # Feature names used for gating (see the analytics endpoint).
    features: list[str]
|
||||
|
||||
|
||||
# Tier configurations
TIER_CONFIGS: dict[Tier, TierConfig] = {
    Tier.FREE: TierConfig(
        requests_per_minute=10,
        requests_per_hour=100,
        requests_per_day=500,
        burst_size=5,
        features=["basic_api"],
    ),
    Tier.STARTER: TierConfig(
        requests_per_minute=60,
        requests_per_hour=1000,
        requests_per_day=10000,
        burst_size=20,
        features=["basic_api", "webhooks"],
    ),
    Tier.PRO: TierConfig(
        requests_per_minute=300,
        requests_per_hour=10000,
        requests_per_day=100000,
        burst_size=50,
        features=["basic_api", "webhooks", "analytics", "priority_support"],
    ),
    Tier.ENTERPRISE: TierConfig(
        requests_per_minute=1000,
        requests_per_hour=50000,
        requests_per_day=500000,
        burst_size=200,
        features=["basic_api", "webhooks", "analytics", "priority_support", "sla", "custom_integrations"],
    ),
}


# Simulated API key database — a real service would look these up in a store.
API_KEYS: dict[str, dict[str, Any]] = {
    "free-key-123": {"tier": Tier.FREE, "user_id": "user_1"},
    "starter-key-456": {"tier": Tier.STARTER, "user_id": "user_2"},
    "pro-key-789": {"tier": Tier.PRO, "user_id": "user_3"},
    "enterprise-key-000": {"tier": Tier.ENTERPRISE, "user_id": "user_4"},
}
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn RateLimitExceeded into a 429 that mentions the caller's tier."""
    # Re-resolve the tier from the header; unknown/missing keys fall back
    # to FREE so the response can always name a tier.
    api_key = request.headers.get("X-API-Key", "")
    key_info = API_KEYS.get(api_key, {})
    tier = key_info.get("tier", Tier.FREE)

    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
            "tier": tier.value,
            # No upsell for customers already on the top tier.
            "upgrade_url": "https://example.com/pricing" if tier != Tier.ENTERPRISE else None,
        },
        headers=exc.limit_info.to_headers() if exc.limit_info else {},
    )
|
||||
|
||||
|
||||
def get_api_key_info(request: Request) -> dict[str, Any]:
    """Resolve the caller's API key to its account record.

    Returns the stored record merged with the key itself.

    Raises:
        HTTPException: 401 when the key is missing or unknown.
    """
    api_key = request.headers.get("X-API-Key")
    if not api_key:
        raise HTTPException(status_code=401, detail="API key required")
    record = API_KEYS.get(api_key)
    if not record:
        raise HTTPException(status_code=401, detail="Invalid API key")
    merged: dict[str, Any] = {"api_key": api_key}
    merged.update(record)
    return merged


def get_tier_config(key_info: dict[str, Any] = Depends(get_api_key_info)) -> TierConfig:
    """Look up the rate-limit configuration for the caller's tier."""
    return TIER_CONFIGS[key_info.get("tier", Tier.FREE)]
|
||||
|
||||
|
||||
# Create rate limit dependencies for each tier
# Built once at import time; apply_tier_rate_limit picks the right one
# per request. Each tier gets its own key_prefix so buckets don't mix.
tier_rate_limits: dict[Tier, RateLimitDependency] = {}
for tier, config in TIER_CONFIGS.items():
    tier_rate_limits[tier] = RateLimitDependency(
        limit=config.requests_per_minute,
        window_size=60,
        algorithm=Algorithm.TOKEN_BUCKET,
        burst_size=config.burst_size,
        key_prefix=f"tier_{tier.value}",
    )
|
||||
|
||||
|
||||
def api_key_extractor(request: Request) -> str:
    """Bucket rate limits by API key rather than client address.

    Missing keys collapse into one shared "anonymous" bucket.
    """
    key = request.headers.get("X-API-Key", "anonymous")
    return f"api:{key}"
|
||||
|
||||
|
||||
async def apply_tier_rate_limit(
    request: Request,
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Apply rate limit based on user's tier.

    Selects the pre-built per-tier dependency and invokes it; the
    dependency raises RateLimitExceeded when the quota is spent.
    """
    tier = key_info.get("tier", Tier.FREE)
    rate_limit_dep = tier_rate_limits[tier]
    rate_info = await rate_limit_dep(request)

    # Bundle everything downstream handlers might want in one dict.
    return {
        "tier": tier,
        "config": TIER_CONFIGS[tier],
        "rate_info": rate_info,
        "key_info": key_info,
    }
|
||||
|
||||
|
||||
@app.get("/api/v1/data")
async def get_data(
    request: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Get data with tier-based rate limiting."""
    return {
        "data": {"items": ["item1", "item2", "item3"]},
        "tier": limit_info["tier"].value,
        # Surface the caller's current quota state in the body.
        "rate_limit": {
            "limit": limit_info["rate_info"].limit,
            "remaining": limit_info["rate_info"].remaining,
            "reset_at": limit_info["rate_info"].reset_at,
        },
    }


@app.get("/api/v1/analytics")
async def get_analytics(
    request: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Analytics endpoint - requires Pro tier or higher.

    Feature gating: rejects tiers whose config lacks "analytics".
    """
    tier = limit_info["tier"]
    config = limit_info["config"]

    if "analytics" not in config.features:
        raise HTTPException(
            status_code=403,
            detail=f"Analytics requires Pro tier or higher. Current tier: {tier.value}",
        )

    # Dummy figures for the example.
    return {
        "analytics": {
            "total_requests": 12345,
            "unique_users": 567,
            "avg_response_time_ms": 45,
        },
        "tier": tier.value,
    }
|
||||
|
||||
|
||||
@app.get("/api/v1/tier-info")
async def get_tier_info(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Get information about current tier and limits."""
    tier = key_info.get("tier", Tier.FREE)
    config = TIER_CONFIGS[tier]

    return {
        "tier": tier.value,
        "limits": {
            "requests_per_minute": config.requests_per_minute,
            "requests_per_hour": config.requests_per_hour,
            "requests_per_day": config.requests_per_day,
            "burst_size": config.burst_size,
        },
        "features": config.features,
        # Any tier with a strictly higher per-minute quota counts as an upgrade.
        "upgrade_options": [t.value for t in Tier if TIER_CONFIGS[t].requests_per_minute > config.requests_per_minute],
    }
|
||||
|
||||
|
||||
@app.get("/pricing")
async def pricing() -> dict[str, Any]:
    """Public pricing information.

    Unauthenticated summary of every tier's headline numbers.
    """
    tiers: dict[str, Any] = {}
    for tier, config in TIER_CONFIGS.items():
        tiers[tier.value] = {
            "requests_per_minute": config.requests_per_minute,
            "requests_per_day": config.requests_per_day,
            "features": config.features,
        }
    return {"tiers": tiers}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Test with different API keys:
    # curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
    # curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
208
examples/09_custom_responses.py
Normal file
208
examples/09_custom_responses.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Examples demonstrating custom rate limit responses and callbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
# Module-level logger used by the blocked-request callbacks below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter on startup and release it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Custom Responses Example", lifespan=lifespan)
|
||||
|
||||
|
||||
# 1. Standard JSON error response
# NOTE(review): a second @app.exception_handler(RateLimitExceeded) is
# registered later in this file; FastAPI keeps one handler per exception
# class, so the later registration presumably replaces this one — confirm
# which handler is intended to win.
@app.exception_handler(RateLimitExceeded)
async def json_rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Standard JSON response for API clients."""
    headers = exc.limit_info.to_headers() if exc.limit_info else {}

    return JSONResponse(
        status_code=429,
        content={
            "error": {
                "code": "RATE_LIMIT_EXCEEDED",
                "message": exc.message,
                "retry_after_seconds": exc.retry_after,
                "documentation_url": "https://docs.example.com/rate-limits",
            },
            # Echo the correlation id if the client supplied one.
            "request_id": request.headers.get("X-Request-ID", "unknown"),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
        headers=headers,
    )
|
||||
|
||||
|
||||
# 2. Callback for logging/monitoring when requests are blocked
async def log_blocked_request(request: Request, info: Any) -> None:
    """Log blocked requests for monitoring.

    Invoked via the decorator's on_blocked hook; purely observational,
    it does not alter the 429 response.
    """
    client_ip = request.client.host if request.client else "unknown"
    path = request.url.path
    user_agent = request.headers.get("User-Agent", "unknown")

    # Lazy %-style args keep formatting off the hot path.
    logger.warning(
        "Rate limit exceeded: ip=%s path=%s user_agent=%s remaining=%s",
        client_ip,
        path,
        user_agent,
        info.remaining if info else "unknown",
    )

    # In production, you might:
    # - Send to metrics system (Prometheus, DataDog, etc.)
    # - Trigger alerts for suspicious patterns
    # - Update a blocklist for repeat offenders


@app.get("/api/monitored")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=log_blocked_request,
)
async def monitored_endpoint(request: Request) -> dict[str, str]:
    """Endpoint with blocked request logging."""
    return {"message": "Success"}
|
||||
|
||||
|
||||
# 3. Custom error messages per endpoint
@app.get("/api/search")
@rate_limit(
    limit=10,
    window_size=60,
    error_message="Search rate limit exceeded. Please wait before searching again.",
)
async def search_endpoint(request: Request, q: str = "") -> dict[str, Any]:
    """Search with custom error message."""
    return {"query": q, "results": []}


@app.get("/api/upload")
@rate_limit(
    limit=5,
    window_size=300,  # 5 uploads per 5 minutes
    error_message="Upload limit reached. You can upload 5 files every 5 minutes.",
)
async def upload_endpoint(request: Request) -> dict[str, str]:
    """Upload with custom error message."""
    return {"message": "Upload successful"}
|
||||
|
||||
|
||||
# 4. Different response formats based on Accept header
@app.get("/api/flexible")
@rate_limit(limit=10, window_size=60)
async def flexible_endpoint(request: Request) -> dict[str, str]:
    """Endpoint that returns different formats."""
    return {"message": "Success", "data": "Some data"}
|
||||
|
||||
|
||||
# Custom exception handler that respects Accept header
# NOTE(review): this registers a second handler for RateLimitExceeded;
# it presumably replaces json_rate_limit_handler defined earlier in this
# file, since FastAPI keeps one handler per exception class — confirm
# that is the intended behavior for this example.
@app.exception_handler(RateLimitExceeded)
async def flexible_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Return response in format matching Accept header."""
    accept = request.headers.get("Accept", "application/json")
    headers = exc.limit_info.to_headers() if exc.limit_info else {}

    if "text/html" in accept:
        # Browser-friendly HTML page.
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head><title>Rate Limit Exceeded</title></head>
        <body>
            <h1>429 - Too Many Requests</h1>
            <p>{exc.message}</p>
            <p>Please try again in {exc.retry_after:.0f} seconds.</p>
        </body>
        </html>
        """
        return HTMLResponse(content=html_content, status_code=429, headers=headers)

    elif "text/plain" in accept:
        return PlainTextResponse(
            content=f"Rate limit exceeded. Retry after {exc.retry_after:.0f} seconds.",
            status_code=429,
            headers=headers,
        )

    else:
        # Default: JSON for API clients.
        return JSONResponse(
            status_code=429,
            content={
                "error": "rate_limit_exceeded",
                "message": exc.message,
                "retry_after": exc.retry_after,
            },
            headers=headers,
        )
|
||||
|
||||
|
||||
# 5. Include helpful information in response headers
@app.get("/api/verbose-headers")
@rate_limit(
    limit=10,
    window_size=60,
    include_headers=True,  # Includes X-RateLimit-* headers
)
async def verbose_headers_endpoint(request: Request) -> dict[str, Any]:
    """Response includes detailed rate limit headers."""
    return {
        "message": "Check response headers for rate limit info",
        "headers_included": [
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
            "X-RateLimit-Reset",
        ],
    }
|
||||
|
||||
|
||||
# 6. Graceful degradation - return cached/stale data instead of error
# Module-level stale copy; timestamp fixed at import time.
cached_data = {"data": "Cached response", "cached_at": datetime.now(timezone.utc).isoformat()}


async def return_cached_on_limit(request: Request, info: Any) -> None:
    """Log when rate limited (callback doesn't prevent exception)."""
    logger.info("Returning cached data due to rate limit")
    # This callback is called when blocked, but doesn't prevent the exception
    # To actually return cached data, you'd need custom middleware


@app.get("/api/graceful")
@rate_limit(
    limit=5,
    window_size=60,
    on_blocked=return_cached_on_limit,
)
async def graceful_endpoint(request: Request) -> dict[str, str]:
    """Endpoint with graceful degradation."""
    return {"message": "Fresh data", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Local demo entry point.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
325
examples/10_advanced_patterns.py
Normal file
325
examples/10_advanced_patterns.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Advanced patterns and real-world use cases for rate limiting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter on startup and release it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Advanced Patterns", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Minimal 429 handler; forwards the limiter's headers when present."""
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
        headers=exc.limit_info.to_headers() if exc.limit_info else {},
    )
|
||||
|
||||
|
||||
# =============================================================================
# Pattern 1: Cost-based rate limiting
# Different operations consume different amounts of quota
# All four endpoints share one 100-token/minute budget; each request
# deducts its declared cost.
# =============================================================================

@app.get("/api/list")
@rate_limit(limit=100, window_size=60, cost=1)
async def list_items(request: Request) -> dict[str, Any]:
    """Cheap operation - costs 1 token."""
    return {"items": ["a", "b", "c"], "cost": 1}


@app.get("/api/details/{item_id}")
@rate_limit(limit=100, window_size=60, cost=5)
async def get_details(request: Request, item_id: str) -> dict[str, Any]:
    """Medium operation - costs 5 tokens."""
    return {"item_id": item_id, "details": "...", "cost": 5}


@app.post("/api/generate")
@rate_limit(limit=100, window_size=60, cost=20)
async def generate_content(request: Request) -> dict[str, Any]:
    """Expensive operation - costs 20 tokens."""
    return {"generated": "AI-generated content...", "cost": 20}


@app.post("/api/bulk-export")
@rate_limit(limit=100, window_size=60, cost=50)
async def bulk_export(request: Request) -> dict[str, Any]:
    """Very expensive operation - costs 50 tokens."""
    return {"export_url": "https://...", "cost": 50}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 2: Sliding scale exemptions
|
||||
# Gradually reduce limits instead of hard blocking
|
||||
# =============================================================================
|
||||
|
||||
def get_request_priority(request: Request) -> int:
    """Score a request's priority; larger numbers win.

    Premium header -> 100, any Authorization header -> 50, anonymous -> 10.
    """
    headers = request.headers
    if headers.get("X-Premium-User") == "true":
        return 100
    return 50 if headers.get("Authorization") else 10


def should_exempt_high_priority(request: Request) -> bool:
    """True when the request is important enough to skip rate limiting."""
    return get_request_priority(request) >= 100
|
||||
|
||||
|
||||
@app.get("/api/priority-based")
@rate_limit(
    limit=10,
    window_size=60,
    exempt_when=should_exempt_high_priority,  # premium callers bypass the limit
)
async def priority_endpoint(request: Request) -> dict[str, Any]:
    """Premium users are exempt from rate limits."""
    priority = get_request_priority(request)
    return {
        "message": "Success",
        "priority": priority,
        "exempt": priority >= 100,
    }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 3: Rate limit by resource, not just user
|
||||
# Prevent abuse of specific resources
|
||||
# =============================================================================
|
||||
|
||||
def resource_key_extractor(request: Request) -> str:
    """Scope the limit to a (resource, user) pair so one user cannot
    hammer a single resource while others stay unaffected.
    """
    rid = request.path_params.get("resource_id", "unknown")
    uid = request.headers.get("X-User-ID", "anonymous")
    return ":".join(("resource", rid, "user", uid))
|
||||
|
||||
|
||||
@app.get("/api/resources/{resource_id}")
@rate_limit(
    limit=10,
    window_size=60,
    key_extractor=resource_key_extractor,
)
async def get_resource(request: Request, resource_id: str) -> dict[str, str]:
    """Each user can access each resource 10 times per minute."""
    return {"resource_id": resource_id, "data": "..."}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 4: Login/authentication rate limiting
|
||||
# Prevent brute force attacks
|
||||
# =============================================================================
|
||||
|
||||
def login_key_extractor(request: Request) -> str:
    """Combine client IP and attempted username into one brute-force key.

    A real application would parse the username from the request body;
    this example reads it from the X-Username header instead.
    """
    client = request.client
    addr = client.host if client else "unknown"
    who = request.headers.get("X-Username", "unknown")
    return f"login:{addr}:{who}"
|
||||
|
||||
|
||||
@app.post("/auth/login")
@rate_limit(
    limit=5,
    window_size=300,  # 5 attempts per 5 minutes
    algorithm=Algorithm.SLIDING_WINDOW,  # Precise tracking for security
    key_extractor=login_key_extractor,
    error_message="Too many login attempts. Please try again in 5 minutes.",
)
async def login(request: Request) -> dict[str, str]:
    """Login endpoint with brute force protection."""
    # Stubbed success response for the example.
    return {"message": "Login successful", "token": "..."}
|
||||
|
||||
|
||||
# Password reset - even stricter limits
def password_reset_key(request: Request) -> str:
    """Key password-reset attempts by client IP only."""
    ip = request.client.host if request.client else "unknown"
    return f"password_reset:{ip}"


@app.post("/auth/password-reset")
@rate_limit(
    limit=3,
    window_size=3600,  # 3 attempts per hour
    key_extractor=password_reset_key,
    error_message="Too many password reset requests. Please try again later.",
)
async def password_reset(request: Request) -> dict[str, str]:
    """Password reset with strict rate limiting."""
    return {"message": "Password reset email sent"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 5: Webhook/callback rate limiting
|
||||
# Limit outgoing requests to prevent overwhelming external services
|
||||
# =============================================================================
|
||||
|
||||
# Dependency that caps outgoing webhook traffic at 100/minute overall.
webhook_rate_limit = RateLimitDependency(
    limit=100,
    window_size=60,
    key_prefix="webhook",
)


async def check_webhook_limit(
    request: Request,
    webhook_url: str,
) -> None:
    """Check rate limit before sending webhook.

    NOTE(review): this is an illustrative stub — it builds a per-domain
    key but never consults the limiter, and its only caller below is
    commented out.
    """
    # Create key based on destination domain
    from urllib.parse import urlparse
    domain = urlparse(webhook_url).netloc
    _key = f"webhook:{domain}"  # Would be used with limiter in production

    # Manually check limit (simplified example)
    # In production, you'd use the limiter directly
    _ = _key  # Suppress unused variable warning
|
||||
|
||||
|
||||
@app.post("/api/send-webhook")
async def send_webhook(
    request: Request,
    webhook_url: str = "https://example.com/webhook",
    rate_info: Any = Depends(webhook_rate_limit),
) -> dict[str, Any]:
    """Send webhook with rate limiting to protect external services."""
    # await check_webhook_limit(request, webhook_url)
    return {
        "message": "Webhook sent",
        "destination": webhook_url,
        "remaining_quota": rate_info.remaining,
    }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 6: Request fingerprinting
|
||||
# Detect and limit similar requests (e.g., spam prevention)
|
||||
# =============================================================================
|
||||
|
||||
def request_fingerprint(request: Request) -> str:
    """Hash stable request attributes into a short device fingerprint.

    MD5 is used purely as a fast, non-cryptographic bucket key here;
    collisions only merge rate-limit buckets.
    """
    addr = request.client.host if request.client else "unknown"
    traits = (
        addr,
        request.headers.get("User-Agent", ""),
        request.headers.get("Accept-Language", ""),
    )
    digest = hashlib.md5(":".join(traits).encode()).hexdigest()
    return f"fingerprint:{digest[:16]}"
|
||||
|
||||
|
||||
@app.post("/api/submit-form")
|
||||
@rate_limit(
|
||||
limit=5,
|
||||
window_size=60,
|
||||
key_extractor=request_fingerprint,
|
||||
error_message="Too many submissions from this device.",
|
||||
)
|
||||
async def submit_form(request: Request) -> dict[str, str]:
|
||||
"""Form submission with fingerprint-based rate limiting."""
|
||||
return {"message": "Form submitted successfully"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 7: Time-of-day based limits
|
||||
# Different limits during peak vs off-peak hours
|
||||
# =============================================================================
|
||||
|
||||
def is_peak_hours() -> bool:
    """True between 09:00 (inclusive) and 18:00 (exclusive) UTC."""
    return time.gmtime().tm_hour in range(9, 18)


def peak_aware_exempt(request: Request) -> bool:
    """Skip rate limiting entirely outside peak hours."""
    return not is_peak_hours()
|
||||
|
||||
|
||||
@app.get("/api/peak-aware")
@rate_limit(
    limit=10,  # Strict limit during peak hours
    window_size=60,
    exempt_when=peak_aware_exempt,  # No limit during off-peak
)
async def peak_aware_endpoint(request: Request) -> dict[str, Any]:
    """Stricter limits during peak hours."""
    return {
        "message": "Success",
        "is_peak_hours": is_peak_hours(),
        "rate_limited": is_peak_hours(),
    }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pattern 8: Cascading limits (multiple tiers)
|
||||
# =============================================================================
|
||||
|
||||
# Three independent buckets at different window sizes; a request must
# have quota in all of them.
per_second = RateLimitDependency(limit=5, window_size=1, key_prefix="sec")
per_minute = RateLimitDependency(limit=100, window_size=60, key_prefix="min")
per_hour = RateLimitDependency(limit=1000, window_size=3600, key_prefix="hour")


async def cascading_limits(
    request: Request,
    sec_info: Any = Depends(per_second),
    min_info: Any = Depends(per_minute),
    hour_info: Any = Depends(per_hour),
) -> dict[str, Any]:
    """Apply multiple rate limit tiers.

    Each Depends() above enforces its own limit before this body runs;
    the body only reports the remaining quota at every tier.
    """
    return {
        "per_second": {"remaining": sec_info.remaining},
        "per_minute": {"remaining": min_info.remaining},
        "per_hour": {"remaining": hour_info.remaining},
    }
|
||||
|
||||
|
||||
@app.get("/api/cascading")
async def cascading_endpoint(
    request: Request,
    limits: dict[str, Any] = Depends(cascading_limits),
) -> dict[str, Any]:
    """Endpoint with per-second, per-minute, and per-hour limits."""
    return {"message": "Success", "limits": limits}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Port 8001 so this example can run beside the others.
    uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
133
examples/README.md
Normal file
133
examples/README.md
Normal file
@@ -0,0 +1,133 @@
|
||||
# FastAPI Traffic Examples
|
||||
|
||||
This directory contains comprehensive examples demonstrating how to use the `fastapi-traffic` rate limiting library.
|
||||
|
||||
## Basic Examples
|
||||
|
||||
### 01_quickstart.py
|
||||
Minimal setup to get rate limiting working. Start here if you're new to the library.
|
||||
- Basic backend and limiter setup
|
||||
- Exception handler for rate limit errors
|
||||
- Simple decorator usage
|
||||
|
||||
### 02_algorithms.py
|
||||
Demonstrates all available rate limiting algorithms:
|
||||
- **Fixed Window** - Simple, resets at fixed intervals
|
||||
- **Sliding Window** - Most precise, stores timestamps
|
||||
- **Sliding Window Counter** - Balance of precision and efficiency (default)
|
||||
- **Token Bucket** - Allows controlled bursts
|
||||
- **Leaky Bucket** - Smooths out traffic
|
||||
|
||||
### 03_backends.py
|
||||
Shows different storage backends:
|
||||
- **MemoryBackend** - Fast, ephemeral (default)
|
||||
- **SQLiteBackend** - Persistent, single-instance
|
||||
- **RedisBackend** - Distributed, multi-instance
|
||||
|
||||
### 04_key_extractors.py
|
||||
Custom key extractors for different rate limiting strategies:
|
||||
- Rate limit by IP address (default)
|
||||
- Rate limit by API key
|
||||
- Rate limit by user ID
|
||||
- Rate limit by endpoint + IP
|
||||
- Rate limit by tenant/organization
|
||||
- Composite keys (user + action)
|
||||
|
||||
### 05_middleware.py
|
||||
Middleware-based rate limiting for global protection:
|
||||
- Basic middleware setup
|
||||
- Custom configuration options
|
||||
- Path and IP exemptions
|
||||
- Alternative middleware classes
|
||||
|
||||
## Advanced Examples
|
||||
|
||||
### 06_dependency_injection.py
|
||||
Using FastAPI's dependency injection system:
|
||||
- Basic rate limit dependency
|
||||
- Tier-based rate limiting
|
||||
- Combining multiple rate limits
|
||||
- Conditional exemptions
|
||||
|
||||
### 07_redis_distributed.py
|
||||
Redis backend for distributed deployments:
|
||||
- Multi-instance rate limiting
|
||||
- Shared counters across nodes
|
||||
- Health checks and statistics
|
||||
- Fallback to memory backend
|
||||
|
||||
### 08_tiered_api.py
|
||||
Production-ready tiered API example:
|
||||
- Free, Starter, Pro, Enterprise tiers
|
||||
- Different limits per tier
|
||||
- Feature gating based on tier
|
||||
- API key validation
|
||||
|
||||
### 09_custom_responses.py
|
||||
Customizing rate limit responses:
|
||||
- Custom JSON error responses
|
||||
- Logging/monitoring callbacks
|
||||
- Different response formats (JSON, HTML, plain text)
|
||||
- Rate limit headers
|
||||
|
||||
### 10_advanced_patterns.py
|
||||
Real-world patterns and use cases:
|
||||
- **Cost-based limiting** - Different operations cost different amounts
|
||||
- **Priority exemptions** - Premium users exempt from limits
|
||||
- **Resource-based limiting** - Limit by resource ID + user
|
||||
- **Login protection** - Brute force prevention
|
||||
- **Webhook limiting** - Protect external services
|
||||
- **Request fingerprinting** - Spam prevention
|
||||
- **Time-of-day limits** - Peak vs off-peak hours
|
||||
- **Cascading limits** - Per-second, per-minute, per-hour
|
||||
|
||||
## Running Examples
|
||||
|
||||
Each example is a standalone FastAPI application. Run with:
|
||||
|
||||
```bash
|
||||
# Using uvicorn directly
|
||||
uvicorn examples.01_quickstart:app --reload
|
||||
|
||||
# Or run the file directly
|
||||
python examples/01_quickstart.py
|
||||
```
|
||||
|
||||
## Testing Rate Limits
|
||||
|
||||
Use curl or httpie to test:
|
||||
|
||||
```bash
|
||||
# Basic request
|
||||
curl http://localhost:8000/api/basic
|
||||
|
||||
# With API key
|
||||
curl -H "X-API-Key: my-key" http://localhost:8000/api/by-api-key
|
||||
|
||||
# Check rate limit headers
|
||||
curl -i http://localhost:8000/api/data
|
||||
|
||||
# Rapid requests to trigger rate limit
|
||||
for i in {1..20}; do curl http://localhost:8000/api/basic; done
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Some examples support configuration via environment variables:
|
||||
|
||||
- `RATE_LIMIT_BACKEND` - Backend type (memory, sqlite, redis)
|
||||
- `REDIS_URL` - Redis connection URL for distributed examples
|
||||
|
||||
## Requirements
|
||||
|
||||
Basic examples only need `fastapi-traffic` and `uvicorn`:
|
||||
|
||||
```bash
|
||||
pip install fastapi-traffic uvicorn
|
||||
```
|
||||
|
||||
For Redis examples:
|
||||
|
||||
```bash
|
||||
pip install redis
|
||||
```
|
||||
171
examples/basic_usage.py
Normal file
171
examples/basic_usage.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Basic usage examples for fastapi-traffic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
SQLiteBackend,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.decorator import RateLimitDependency
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
# Configure a module-level rate limiter backed by SQLite so counters
# persist across process restarts (creates rate_limits.db in the CWD).
backend = SQLiteBackend("rate_limits.db")
limiter = RateLimiter(backend)
# Register the limiter globally so @rate_limit decorators can resolve it.
# NOTE(review): this runs at import time, before limiter.initialize() in
# lifespan — presumably set_limiter only stores a reference; confirm.
set_limiter(limiter)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Initialize the rate limiter before serving and close it afterwards."""
    await limiter.initialize()  # startup: open backend resources
    yield  # application handles requests here
    await limiter.close()  # shutdown: release backend resources
|
||||
|
||||
|
||||
# Application instance; `lifespan` brackets serving with limiter init/close.
app = FastAPI(title="FastAPI Traffic Example", lifespan=lifespan)
|
||||
|
||||
|
||||
# Convert RateLimitExceeded into an HTTP 429 response, echoing any
# rate-limit headers the limiter attached to the exception.
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Render a 429 JSON payload describing the rate-limit violation."""
    limit_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    payload = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(status_code=429, content=payload, headers=limit_headers)
|
||||
|
||||
|
||||
# Example 1: simplest possible decorator usage.
@app.get("/api/basic")
@rate_limit(limit=100, window_size=60)  # at most 100 requests per 60 s
async def basic_endpoint(request: Request) -> dict[str, str]:
    """Return a static greeting; excess requests trigger RateLimitExceeded."""
    return {"message": "Hello, World!"}
|
||||
|
||||
|
||||
# Example 2: token-bucket algorithm with burst capacity.
@app.get("/api/token-bucket")
@rate_limit(
    limit=50,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=10,  # short bursts of up to 10 requests are tolerated
)
async def token_bucket_endpoint(request: Request) -> dict[str, str]:
    """Demonstrate token-bucket limiting: 50/minute with bursts of 10."""
    return {"message": "Token bucket rate limiting"}
|
||||
|
||||
|
||||
# Example 3: sliding-window algorithm for smoother counting than a
# fixed window (no burst at window boundaries).
@app.get("/api/sliding-window")
@rate_limit(
    limit=30,
    window_size=60,
    algorithm=Algorithm.SLIDING_WINDOW,
)
async def sliding_window_endpoint(request: Request) -> dict[str, str]:
    """Demonstrate sliding-window limiting: 30 requests per minute."""
    return {"message": "Sliding window rate limiting"}
|
||||
|
||||
|
||||
# Example 4: custom key extractor — bucket requests by API key.
def api_key_extractor(request: Request) -> str:
    """Derive the rate-limit bucket key from the X-API-Key header.

    Requests without the header all share a single "anonymous" bucket.
    """
    key = request.headers.get("X-API-Key", "anonymous")
    return "api_key:" + key
|
||||
|
||||
|
||||
@app.get("/api/by-api-key")
@rate_limit(
    limit=1000,
    window_size=3600,  # i.e. 1000 requests per hour, per API key
    key_extractor=api_key_extractor,
)
async def api_key_endpoint(request: Request) -> dict[str, str]:
    """Endpoint whose quota is tracked per X-API-Key value."""
    return {"message": "Rate limited by API key"}
|
||||
|
||||
|
||||
# Example 5: dependency-injection style limiting (20 requests per minute).
# The dependency enforces the limit and yields rate-limit info to the endpoint.
rate_limit_dep = RateLimitDependency(limit=20, window_size=60)
|
||||
|
||||
|
||||
@app.get("/api/dependency")
async def dependency_endpoint(
    request: Request,
    rate_info: dict[str, object] = Depends(rate_limit_dep),
) -> dict[str, object]:
    """Expose the rate-limit state supplied by the injected dependency."""
    return {
        "message": "Rate limit info available",
        "rate_limit": rate_info,
    }
|
||||
|
||||
|
||||
# Example 6: predicate used to exempt admin traffic from rate limiting.
def is_admin(request: Request) -> bool:
    """Return True when the request presents the admin bypass token."""
    token = request.headers.get("X-Admin-Token")
    return token == "secret-admin-token"
|
||||
|
||||
|
||||
# Apply a tight limit, but skip enforcement entirely when is_admin() is True.
@app.get("/api/admin-exempt")
@rate_limit(limit=10, window_size=60, exempt_when=is_admin)
async def admin_exempt_endpoint(request: Request) -> dict[str, str]:
    """Rate-limited endpoint that admin-token requests bypass."""
    return {"message": "Admins are exempt from rate limiting"}
|
||||
|
||||
|
||||
# Example 7: weighted costs — one call here consumes 10 units of the
# 100-unit quota, so roughly 10 such calls fit per minute.
@app.post("/api/expensive")
@rate_limit(limit=100, window_size=60, cost=10)
async def expensive_endpoint(request: Request) -> dict[str, str]:
    """Simulate a heavyweight operation that drains the quota faster."""
    return {"message": "Expensive operation completed"}
|
||||
|
||||
|
||||
# Example 8: Global middleware rate limiting
|
||||
# Uncomment to enable global rate limiting
|
||||
# app.add_middleware(
|
||||
# RateLimitMiddleware,
|
||||
# limit=1000,
|
||||
# window_size=60,
|
||||
# exempt_paths={"/health", "/docs", "/openapi.json"},
|
||||
# )
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Liveness probe; typically left outside any global rate limits."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running the example directly: python examples/basic_usage.py
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user