Initial commit: fastapi-traffic rate limiting library

- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.)
- SQLite and memory backends
- Decorator and dependency injection patterns
- Middleware support
- Example usage files
This commit is contained in:
2026-01-09 00:26:19 +00:00
commit da496746bb
38 changed files with 5790 additions and 0 deletions

60
examples/01_quickstart.py Normal file
View File

@@ -0,0 +1,60 @@
"""Quickstart example - minimal setup to get rate limiting working."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Step 1: Create a backend and limiter.
# MemoryBackend keeps counters in-process; RateLimiter drives the algorithms.
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the limiter before serving and release it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)  # register globally so @rate_limit can find it
    yield
    await limiter.close()


app = FastAPI(title="Quickstart Example", lifespan=lifespan)


# Step 2: Add exception handler for rate limit errors
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate RateLimitExceeded into a 429 JSON response."""
    body = {"error": "Too many requests", "retry_after": exc.retry_after}
    return JSONResponse(status_code=429, content=body)
# Step 3: Apply rate limiting to endpoints
@app.get("/")
@rate_limit(10, 60)  # at most 10 calls per rolling minute
async def hello(request: Request) -> dict[str, str]:
    """Greeting endpoint, limited to 10 requests/minute per client."""
    return {"message": "Hello, World!"}


@app.get("/api/data")
@rate_limit(100, 60)  # at most 100 calls per minute
async def get_data(request: Request) -> dict[str, str]:
    """Data endpoint with a higher allowance (100 requests/minute)."""
    return {"data": "Some important data"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

122
examples/02_algorithms.py Normal file
View File

@@ -0,0 +1,122 @@
"""Examples demonstrating all available rate limiting algorithms."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize/tear down the shared limiter with the app lifecycle."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Rate Limiting Algorithms", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Return a 429 carrying standard rate-limit headers when available."""
    extra_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    payload = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(status_code=429, content=payload, headers=extra_headers)
# 1. Fixed Window - resets the counter at fixed intervals.
#    Best for: simple use cases, low memory usage.
#    Drawback: can allow a 2x burst at window boundaries.
@app.get("/fixed-window")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
async def fixed_window(request: Request) -> dict[str, str]:
    """Fixed window resets counter at fixed time intervals."""
    return {"algorithm": "fixed_window", "description": "Counter resets every 60 seconds"}


# 2. Sliding Window Log - the most precise option.
#    Best for: when accuracy is critical.
#    Drawback: higher memory usage (stores all timestamps).
@app.get("/sliding-window")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
async def sliding_window(request: Request) -> dict[str, str]:
    """Sliding window tracks exact timestamps for precise limiting."""
    return {"algorithm": "sliding_window", "description": "Precise tracking with timestamp log"}


# 3. Sliding Window Counter - a balance of precision and efficiency.
#    Best for: most production use cases (the default algorithm); combines
#    fixed-window efficiency with sliding-window precision.
@app.get("/sliding-window-counter")
@rate_limit(limit=10, window_size=60, algorithm=Algorithm.SLIDING_WINDOW_COUNTER)
async def sliding_window_counter(request: Request) -> dict[str, str]:
    """Sliding window counter uses weighted counts from current and previous windows."""
    return {"algorithm": "sliding_window_counter", "description": "Efficient approximation"}


# 4. Token Bucket - tokens refill gradually; burst_size caps the burst.
#    Best for: APIs that need to allow occasional bursts.
@app.get("/token-bucket")
@rate_limit(
    limit=10,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=5,  # Allow bursts of up to 5 requests
)
async def token_bucket(request: Request) -> dict[str, str]:
    """Token bucket allows bursts up to burst_size, then refills gradually."""
    return {"algorithm": "token_bucket", "description": "Allows controlled bursts"}


# 5. Leaky Bucket - processes requests at a constant rate.
#    Best for: protecting downstream services from traffic spikes.
@app.get("/leaky-bucket")
@rate_limit(
    limit=10,
    window_size=60,
    algorithm=Algorithm.LEAKY_BUCKET,
    burst_size=5,  # Queue capacity
)
async def leaky_bucket(request: Request) -> dict[str, str]:
    """Leaky bucket smooths traffic to a constant rate."""
    return {"algorithm": "leaky_bucket", "description": "Constant output rate"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

108
examples/03_backends.py Normal file
View File

@@ -0,0 +1,108 @@
"""Examples demonstrating different storage backends."""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimitExceeded,
RateLimiter,
SQLiteBackend,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Choose backend based on environment
def get_backend():
    """Select a rate-limit storage backend from the RATE_LIMIT_BACKEND env var.

    Supported values:
      - "memory" (default): in-process storage, ephemeral, single-instance.
      - "sqlite": file-backed storage, persists across restarts.
      - "redis": distributed storage for multi-instance deployments; falls
        back to memory if the redis package is not installed.

    Returns:
        A backend instance (MemoryBackend, SQLiteBackend, or RedisBackend).
    """
    backend_type = os.getenv("RATE_LIMIT_BACKEND", "memory")
    if backend_type == "sqlite":
        # SQLite - Good for single-instance apps, persists across restarts
        return SQLiteBackend("rate_limits.db")
    elif backend_type == "redis":
        # Redis - Required for distributed/multi-instance deployments
        # Requires: pip install redis
        try:
            from fastapi_traffic import RedisBackend
            import asyncio

            async def create_redis():
                return await RedisBackend.from_url(
                    os.getenv("REDIS_URL", "redis://localhost:6379/0"),
                    key_prefix="myapp_ratelimit",
                )

            # FIX: asyncio.get_event_loop().run_until_complete() is deprecated
            # when no loop is running (3.10+) and raises on 3.12+. asyncio.run()
            # creates and tears down a fresh loop, which is correct here since
            # get_backend() runs at import time, outside any event loop.
            return asyncio.run(create_redis())
        except ImportError:
            print("Redis not installed, falling back to memory backend")
            return MemoryBackend()
    else:
        # Memory - Fast, but resets on restart, not shared across instances
        return MemoryBackend()
backend = get_backend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Bring the limiter up on startup and close it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Storage Backends Example", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """429 responder carrying the retry-after hint."""
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
    )


@app.get("/api/resource")
@rate_limit(100, 60)
async def get_resource(request: Request) -> dict[str, str]:
    """Rate-limited resource endpoint; reports which backend is active."""
    return {"message": "Resource data", "backend": type(backend).__name__}
@app.get("/backend-info")
async def backend_info() -> dict[str, Any]:
    """Describe the active storage backend."""
    if isinstance(backend, MemoryBackend):
        description = "In-memory storage, fast but ephemeral"
    elif isinstance(backend, SQLiteBackend):
        description = "SQLite storage, persistent, single-instance"
    else:
        # Anything else is assumed to be the Redis backend.
        description = "Redis storage, distributed, multi-instance"
    return {"backend_type": type(backend).__name__, "description": description}


if __name__ == "__main__":
    import uvicorn

    # Run with different backends:
    #   RATE_LIMIT_BACKEND=memory python 03_backends.py
    #   RATE_LIMIT_BACKEND=sqlite python 03_backends.py
    #   RATE_LIMIT_BACKEND=redis REDIS_URL=redis://localhost:6379/0 python 03_backends.py
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,153 @@
"""Examples demonstrating custom key extractors for rate limiting."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage limiter lifecycle alongside the application."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Custom Key Extractors", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Uniform 429 payload for all endpoints in this example."""
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
    )


# 1. Default: Rate limit by IP address
@app.get("/by-ip")
@rate_limit(10, 60)  # no key_extractor -> default IP-based key
async def by_ip(request: Request) -> dict[str, str]:
    """Rate limited by client IP address (default behavior)."""
    client_ip = request.client.host if request.client else "unknown"
    return {"limited_by": "ip", "client_ip": client_ip}
# 2. Rate limit by API key
def api_key_extractor(request: Request) -> str:
    """Build a rate-limit key from the X-API-Key header.

    Requests without the header all share a single "anonymous" bucket.
    """
    return "api_key:" + request.headers.get("X-API-Key", "anonymous")
@app.get("/by-api-key")
@rate_limit(
    limit=100,
    window_size=3600,  # 100 requests per hour per API key
    key_extractor=api_key_extractor,
)
async def by_api_key(request: Request) -> dict[str, str]:
    """Rate limited by API key from X-API-Key header."""
    return {
        "limited_by": "api_key",
        "api_key": request.headers.get("X-API-Key", "anonymous"),
    }


# 3. Rate limit by user ID (from JWT or session)
def user_id_extractor(request: Request) -> str:
    """Key on the caller's user id (X-User-ID header in this demo).

    Real applications would derive this from a decoded JWT or session.
    """
    return "user:" + request.headers.get("X-User-ID", "anonymous")


@app.get("/by-user")
@rate_limit(limit=50, window_size=60, key_extractor=user_id_extractor)
async def by_user(request: Request) -> dict[str, str]:
    """Rate limited by user ID."""
    return {
        "limited_by": "user_id",
        "user_id": request.headers.get("X-User-ID", "anonymous"),
    }
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
def endpoint_ip_extractor(request: Request) -> str:
    """Combine the request path with the client IP.

    Gives every endpoint its own per-client counter instead of one
    shared bucket per IP.
    """
    client = request.client
    ip = client.host if client else "unknown"
    return f"endpoint:{request.url.path}:ip:{ip}"
@app.get("/endpoint-specific")
@rate_limit(limit=5, window_size=60, key_extractor=endpoint_ip_extractor)
async def endpoint_specific(request: Request) -> dict[str, str]:
    """Each endpoint has its own rate limit counter."""
    return {"limited_by": "endpoint+ip"}


# 5. Rate limit by organization/tenant (multi-tenant apps)
def tenant_extractor(request: Request) -> str:
    """Key on the tenant id carried in X-Tenant-ID.

    Could also parse the tenant from a subdomain (tenant.example.com).
    """
    return "tenant:" + request.headers.get("X-Tenant-ID", "default")


@app.get("/by-tenant")
@rate_limit(
    limit=1000,
    window_size=3600,  # 1000 requests per hour per tenant
    key_extractor=tenant_extractor,
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Rate limited by tenant/organization."""
    return {
        "limited_by": "tenant",
        "tenant_id": request.headers.get("X-Tenant-ID", "default"),
    }
# 6. Composite key: User + Action type
def user_action_extractor(request: Request) -> str:
    """Key on (user id, action) so each action type is limited separately."""
    who = request.headers.get("X-User-ID", "anonymous")
    what = request.query_params.get("action", "default")
    return f"user:{who}:action:{what}"
@app.get("/user-action")
@rate_limit(limit=10, window_size=60, key_extractor=user_action_extractor)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Rate limited by user + action combination."""
    return {
        "limited_by": "user+action",
        "user_id": request.headers.get("X-User-ID", "anonymous"),
        "action": action,
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

109
examples/05_middleware.py Normal file
View File

@@ -0,0 +1,109 @@
"""Examples demonstrating middleware-based rate limiting."""
from __future__ import annotations
from fastapi import FastAPI, Request
from fastapi_traffic import MemoryBackend
from fastapi_traffic.middleware import RateLimitMiddleware
# Alternative middleware options (uncomment to use):
# from fastapi_traffic.middleware import SlidingWindowMiddleware
# from fastapi_traffic.middleware import TokenBucketMiddleware
app = FastAPI(title="Middleware Rate Limiting")
# Custom key extractor for middleware
def get_client_identifier(request: Request) -> str:
    """Identify the caller: API key if present, else client IP, else "unknown"."""
    key = request.headers.get("X-API-Key")
    if key:
        return f"api_key:{key}"
    client = request.client
    return f"ip:{client.host}" if client else "unknown"
# Option 1: Basic middleware with defaults
# Uncomment to use:
# app.add_middleware(
#     RateLimitMiddleware,
#     limit=100,
#     window_size=60,
# )

# Option 2: Middleware with custom configuration.
# Every knob is spelled out here so the example doubles as a reference.
app.add_middleware(
    RateLimitMiddleware,
    limit=100,
    window_size=60,
    backend=MemoryBackend(),
    key_prefix="global",
    include_headers=True,
    error_message="You have exceeded the rate limit. Please slow down.",
    status_code=429,
    skip_on_error=True,  # Don't block requests if backend fails
    exempt_paths={"/health", "/docs", "/openapi.json", "/redoc"},
    exempt_ips={"127.0.0.1"},  # Exempt localhost
    key_extractor=get_client_identifier,
)

# Option 3: Convenience middleware for specific algorithms.
# SlidingWindowMiddleware - precise rate limiting:
# app.add_middleware(
#     SlidingWindowMiddleware,
#     limit=100,
#     window_size=60,
# )
# TokenBucketMiddleware - allows bursts:
# app.add_middleware(
#     TokenBucketMiddleware,
#     limit=100,
#     window_size=60,
# )
@app.get("/")
async def root() -> dict[str, str]:
    """Root endpoint - rate limited by middleware."""
    return {"message": "Hello, World!"}


@app.get("/api/data")
async def get_data() -> dict[str, str]:
    """API endpoint - rate limited by middleware."""
    return {"data": "Some important data"}


@app.get("/api/users")
async def get_users() -> dict[str, list[str]]:
    """Users endpoint - rate limited by middleware."""
    return {"users": ["alice", "bob", "charlie"]}


@app.get("/health")
async def health() -> dict[str, str]:
    """Health check - exempt from rate limiting (see exempt_paths above)."""
    return {"status": "healthy"}


@app.get("/docs-info")
async def docs_info() -> dict[str, str]:
    """Info about documentation endpoints."""
    return {
        "message": "Visit /docs for Swagger UI or /redoc for ReDoc",
        "note": "These endpoints are exempt from rate limiting",
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,221 @@
"""Examples demonstrating rate limiting with FastAPI dependency injection."""
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import Any
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimitExceeded,
RateLimiter,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter at startup; close it at shutdown."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(title="Dependency Injection Example", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """429 responder shared by every endpoint below."""
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
    )


# 1. Basic dependency - rate limit info available in endpoint
basic_rate_limit = RateLimitDependency(limit=10, window_size=60)


@app.get("/basic")
async def basic_endpoint(
    request: Request,
    rate_info: Any = Depends(basic_rate_limit),
) -> dict[str, Any]:
    """Access rate limit info in your endpoint logic."""
    return {
        "message": "Success",
        "rate_limit": {
            "limit": rate_info.limit,
            "remaining": rate_info.remaining,
            "reset_at": rate_info.reset_at,
        },
    }
# 2. Different limits for different user tiers
def get_user_tier(request: Request) -> str:
    """Read the caller's tier from X-User-Tier (JWT/database in real apps)."""
    tier = request.headers.get("X-User-Tier", "free")
    return tier
free_tier_limit = RateLimitDependency(limit=10, window_size=60, key_prefix="free")
pro_tier_limit = RateLimitDependency(limit=100, window_size=60, key_prefix="pro")
enterprise_tier_limit = RateLimitDependency(
    limit=1000,
    window_size=60,
    key_prefix="enterprise",
)


async def tiered_rate_limit(
    request: Request,
    tier: str = Depends(get_user_tier),
) -> Any:
    """Dispatch to the limiter that matches the caller's tier.

    Unknown tiers fall back to the free-tier limit.
    """
    by_tier = {
        "enterprise": enterprise_tier_limit,
        "pro": pro_tier_limit,
    }
    chosen = by_tier.get(tier, free_tier_limit)
    return await chosen(request)


@app.get("/tiered")
async def tiered_endpoint(
    request: Request,
    rate_info: Any = Depends(tiered_rate_limit),
) -> dict[str, Any]:
    """Endpoint with tier-based rate limiting."""
    tier = get_user_tier(request)
    return {
        "message": "Success",
        "tier": tier,
        "rate_limit": {
            "limit": rate_info.limit,
            "remaining": rate_info.remaining,
        },
    }
# 3. Conditional rate limiting based on request properties
def api_key_extractor(request: Request) -> str:
    """Extract API key for rate limiting."""
    return "api:" + request.headers.get("X-API-Key", "anonymous")


api_rate_limit = RateLimitDependency(
    limit=100,
    window_size=3600,
    key_extractor=api_key_extractor,
)


@app.get("/api/resource")
async def api_resource(
    request: Request,
    rate_info: Any = Depends(api_rate_limit),
) -> dict[str, Any]:
    """API endpoint with per-API-key rate limiting."""
    return {"data": "Resource data", "requests_remaining": rate_info.remaining}


# 4. Combine multiple rate limits (e.g., per-minute AND per-hour)
per_minute_limit = RateLimitDependency(limit=10, window_size=60, key_prefix="minute")
per_hour_limit = RateLimitDependency(limit=100, window_size=3600, key_prefix="hour")


async def combined_rate_limit(
    request: Request,
    minute_info: Any = Depends(per_minute_limit),
    hour_info: Any = Depends(per_hour_limit),
) -> dict[str, Any]:
    """Enforce both windows; report each window's remaining allowance."""
    return {
        "minute": {"limit": minute_info.limit, "remaining": minute_info.remaining},
        "hour": {"limit": hour_info.limit, "remaining": hour_info.remaining},
    }


@app.get("/combined")
async def combined_endpoint(
    request: Request,
    rate_info: dict[str, Any] = Depends(combined_rate_limit),
) -> dict[str, Any]:
    """Endpoint with multiple rate limit tiers."""
    return {"message": "Success", "rate_limits": rate_info}
# 5. Rate limit with custom exemption logic
def is_internal_request(request: Request) -> bool:
    """True when the request carries the internal-service token."""
    return request.headers.get("X-Internal-Token") == "internal-secret-token"
internal_exempt_limit = RateLimitDependency(
    limit=10,
    window_size=60,
    exempt_when=is_internal_request,
)


@app.get("/internal-exempt")
async def internal_exempt_endpoint(
    request: Request,
    rate_info: Any = Depends(internal_exempt_limit),
) -> dict[str, Any]:
    """Internal requests are exempt from rate limiting."""
    internal = is_internal_request(request)
    limit_view = None if internal else {"remaining": rate_info.remaining}
    return {"message": "Success", "is_internal": internal, "rate_limit": limit_view}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,197 @@
"""Example demonstrating Redis backend for distributed rate limiting.
This example shows how to use Redis for rate limiting across multiple
application instances (e.g., in a Kubernetes deployment or load-balanced setup).
Requirements:
pip install redis
Environment variables:
REDIS_URL: Redis connection URL (default: redis://localhost:6379/0)
"""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse
from typing import Annotated
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.backends.redis import RedisBackend
async def create_redis_backend():
    """Create a Redis backend, degrading gracefully to memory.

    Connects and pings Redis at REDIS_URL; any import, connection, or ping
    failure falls back to an in-process MemoryBackend.

    NOTE(review): the module-level `from fastapi_traffic.backends.redis
    import RedisBackend` above may itself fail at import time when the redis
    package is missing, which would make the ImportError branch here
    unreachable — confirm that module imports redis lazily.
    """
    try:
        from fastapi_traffic import RedisBackend

        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        backend = await RedisBackend.from_url(
            redis_url,
            key_prefix="myapp",
        )
        # Verify the connection actually works before committing to it.
        if await backend.ping():
            print(f"Connected to Redis at {redis_url}")
            return backend
        print("Redis ping failed, falling back to memory backend")
        return MemoryBackend()
    except ImportError:
        print("Redis package not installed. Install with: pip install redis")
        print("Falling back to memory backend")
        return MemoryBackend()
    except Exception as e:
        print(f"Failed to connect to Redis: {e}")
        print("Falling back to memory backend")
        return MemoryBackend()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the backend/limiter at startup and stash them on app.state."""
    app.state.backend = await create_redis_backend()
    app.state.limiter = RateLimiter(app.state.backend)
    await app.state.limiter.initialize()
    set_limiter(app.state.limiter)
    yield
    await app.state.limiter.close()


app = FastAPI(
    title="Distributed Rate Limiting with Redis",
    lifespan=lifespan,
)
def get_backend(request: Request) -> RedisBackend | MemoryBackend:
    """Dependency: fetch the backend stored on app state during startup."""
    return request.app.state.backend


def get_limiter(request: Request) -> RateLimiter:
    """Dependency: fetch the limiter stored on app state during startup."""
    return request.app.state.limiter
BackendDep = Annotated[RedisBackend | MemoryBackend, Depends(get_backend)]
LimiterDep = Annotated[RateLimiter, Depends(get_limiter)]
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """429 responder including standard rate-limit headers when known."""
    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=exc.limit_info.to_headers() if exc.limit_info else {},
    )


# Rate limits are shared across all instances when using Redis
@app.get("/api/shared-limit")
@rate_limit(limit=100, window_size=60, key_prefix="shared")
async def shared_limit(request: Request) -> dict[str, str]:
    """This rate limit is shared across all application instances."""
    return {
        "message": "Success",
        "note": "Rate limit counter is shared via Redis",
    }


# Per-user limits also work across instances
def user_extractor(request: Request) -> str:
    """Key on X-User-ID so each user gets an independent counter."""
    return "user:" + request.headers.get("X-User-ID", "anonymous")


@app.get("/api/user-limit")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_extractor,
    key_prefix="user_api",
)
async def user_limit(request: Request) -> dict[str, str]:
    """Per-user rate limit shared across instances."""
    return {
        "message": "Success",
        "user_id": request.headers.get("X-User-ID", "anonymous"),
    }


# Token bucket works well with Redis for burst handling
@app.get("/api/burst-allowed")
@rate_limit(
    limit=100,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=20,
    key_prefix="burst",
)
async def burst_allowed(request: Request) -> dict[str, str]:
    """Token bucket with Redis allows controlled bursts across instances."""
    return {"message": "Burst request successful"}


@app.get("/health")
async def health(backend: BackendDep) -> dict[str, object]:
    """Health check with Redis status."""
    backend_type = type(backend).__name__
    redis_healthy = False
    if hasattr(backend, "ping"):  # not every backend exposes ping
        try:
            redis_healthy = await backend.ping()
        except Exception:
            redis_healthy = False
    return {
        "status": "healthy",
        "backend": backend_type,
        "redis_connected": redis_healthy,
    }


@app.get("/stats")
async def stats(backend: BackendDep) -> dict[str, object]:
    """Get rate limiting statistics from Redis."""
    if not hasattr(backend, "get_stats"):
        return {"message": "Stats not available for this backend"}
    try:
        return await backend.get_stats()
    except Exception as e:
        return {"error": str(e)}


if __name__ == "__main__":
    import uvicorn

    # Run multiple instances on different ports to test distributed limiting:
    #   REDIS_URL=redis://localhost:6379/0 python 07_redis_distributed.py
    # In another terminal:
    #   uvicorn 07_redis_distributed:app --port 8001
    uvicorn.run(app, host="0.0.0.0", port=8000)

256
examples/08_tiered_api.py Normal file
View File

@@ -0,0 +1,256 @@
"""Example of a production-ready tiered API with different rate limits per plan."""
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Any
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimitExceeded,
RateLimiter,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and dispose of the limiter with the app."""
    await limiter.initialize()
    set_limiter(limiter)
    yield
    await limiter.close()


app = FastAPI(
    title="Tiered API Example",
    description="API with different rate limits based on subscription tier",
    lifespan=lifespan,
)
class Tier(str, Enum):
    """Subscription plans, ordered from most to least restricted."""

    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    ENTERPRISE = "enterprise"


@dataclass
class TierConfig:
    """Rate-limit quotas and feature flags attached to one tier."""

    requests_per_minute: int
    requests_per_hour: int
    requests_per_day: int
    burst_size: int
    features: list[str]


# Tier configurations: quotas grow roughly an order of magnitude per step.
TIER_CONFIGS: dict[Tier, TierConfig] = {
    Tier.FREE: TierConfig(
        requests_per_minute=10,
        requests_per_hour=100,
        requests_per_day=500,
        burst_size=5,
        features=["basic_api"],
    ),
    Tier.STARTER: TierConfig(
        requests_per_minute=60,
        requests_per_hour=1000,
        requests_per_day=10000,
        burst_size=20,
        features=["basic_api", "webhooks"],
    ),
    Tier.PRO: TierConfig(
        requests_per_minute=300,
        requests_per_hour=10000,
        requests_per_day=100000,
        burst_size=50,
        features=["basic_api", "webhooks", "analytics", "priority_support"],
    ),
    Tier.ENTERPRISE: TierConfig(
        requests_per_minute=1000,
        requests_per_hour=50000,
        requests_per_day=500000,
        burst_size=200,
        features=["basic_api", "webhooks", "analytics", "priority_support", "sla", "custom_integrations"],
    ),
}

# Simulated API key database (a real app would use persistent storage).
API_KEYS: dict[str, dict[str, Any]] = {
    "free-key-123": {"tier": Tier.FREE, "user_id": "user_1"},
    "starter-key-456": {"tier": Tier.STARTER, "user_id": "user_2"},
    "pro-key-789": {"tier": Tier.PRO, "user_id": "user_3"},
    "enterprise-key-000": {"tier": Tier.ENTERPRISE, "user_id": "user_4"},
}
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """429 responder that also reports the caller's tier and an upsell link."""
    key_info = API_KEYS.get(request.headers.get("X-API-Key", ""), {})
    tier = key_info.get("tier", Tier.FREE)
    show_upgrade = tier != Tier.ENTERPRISE
    return JSONResponse(
        status_code=429,
        content={
            "error": "rate_limit_exceeded",
            "message": exc.message,
            "retry_after": exc.retry_after,
            "tier": tier.value,
            "upgrade_url": "https://example.com/pricing" if show_upgrade else None,
        },
        headers=exc.limit_info.to_headers() if exc.limit_info else {},
    )


def get_api_key_info(request: Request) -> dict[str, Any]:
    """Validate the X-API-Key header against the key store.

    Raises:
        HTTPException: 401 when the header is missing or unrecognized.
    """
    api_key = request.headers.get("X-API-Key")
    if not api_key:
        raise HTTPException(status_code=401, detail="API key required")
    key_info = API_KEYS.get(api_key)
    if not key_info:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return {"api_key": api_key, **key_info}


def get_tier_config(key_info: dict[str, Any] = Depends(get_api_key_info)) -> TierConfig:
    """Get rate limit config for user's tier."""
    return TIER_CONFIGS[key_info.get("tier", Tier.FREE)]
# Create rate limit dependencies for each tier up front, one token-bucket
# limiter per tier with its own key prefix.
tier_rate_limits: dict[Tier, RateLimitDependency] = {
    tier: RateLimitDependency(
        limit=config.requests_per_minute,
        window_size=60,
        algorithm=Algorithm.TOKEN_BUCKET,
        burst_size=config.burst_size,
        key_prefix=f"tier_{tier.value}",
    )
    for tier, config in TIER_CONFIGS.items()
}


# NOTE(review): this extractor is defined but not wired into any limiter in
# this file — presumably intended for RateLimitDependency(key_extractor=...).
def api_key_extractor(request: Request) -> str:
    """Extract API key for rate limiting."""
    return "api:" + request.headers.get("X-API-Key", "anonymous")


async def apply_tier_rate_limit(
    request: Request,
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Look up the caller's tier and enforce that tier's token-bucket limit."""
    tier = key_info.get("tier", Tier.FREE)
    rate_info = await tier_rate_limits[tier](request)
    return {
        "tier": tier,
        "config": TIER_CONFIGS[tier],
        "rate_info": rate_info,
        "key_info": key_info,
    }
@app.get("/api/v1/data")
async def get_data(
    request: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Get data with tier-based rate limiting."""
    rate_info = limit_info["rate_info"]
    return {
        "data": {"items": ["item1", "item2", "item3"]},
        "tier": limit_info["tier"].value,
        "rate_limit": {
            "limit": rate_info.limit,
            "remaining": rate_info.remaining,
            "reset_at": rate_info.reset_at,
        },
    }


@app.get("/api/v1/analytics")
async def get_analytics(
    request: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Analytics endpoint - requires Pro tier or higher."""
    tier = limit_info["tier"]
    config = limit_info["config"]
    if "analytics" not in config.features:
        raise HTTPException(
            status_code=403,
            detail=f"Analytics requires Pro tier or higher. Current tier: {tier.value}",
        )
    return {
        "analytics": {
            "total_requests": 12345,
            "unique_users": 567,
            "avg_response_time_ms": 45,
        },
        "tier": tier.value,
    }


@app.get("/api/v1/tier-info")
async def get_tier_info(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Get information about current tier and limits."""
    tier = key_info.get("tier", Tier.FREE)
    config = TIER_CONFIGS[tier]
    # Tiers with a strictly larger per-minute quota count as upgrades.
    higher_tiers = [
        t.value
        for t in Tier
        if TIER_CONFIGS[t].requests_per_minute > config.requests_per_minute
    ]
    return {
        "tier": tier.value,
        "limits": {
            "requests_per_minute": config.requests_per_minute,
            "requests_per_hour": config.requests_per_hour,
            "requests_per_day": config.requests_per_day,
            "burst_size": config.burst_size,
        },
        "features": config.features,
        "upgrade_options": higher_tiers,
    }


@app.get("/pricing")
async def pricing() -> dict[str, Any]:
    """Public pricing information."""
    tiers: dict[str, Any] = {}
    for tier, config in TIER_CONFIGS.items():
        tiers[tier.value] = {
            "requests_per_minute": config.requests_per_minute,
            "requests_per_day": config.requests_per_day,
            "features": config.features,
        }
    return {"tiers": tiers}


if __name__ == "__main__":
    import uvicorn

    # Test with different API keys:
    #   curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
    #   curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,208 @@
"""Examples demonstrating custom rate limit responses and callbacks."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi_traffic import (
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the blocked-request callbacks below.
logger = logging.getLogger(__name__)
# One shared in-memory backend/limiter pair for the whole example app.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter before serving and release it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)  # register as the process-wide limiter for @rate_limit
    yield
    await limiter.close()
app = FastAPI(title="Custom Responses Example", lifespan=lifespan)


# 1. Standard JSON error response
@app.exception_handler(RateLimitExceeded)
async def json_rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Machine-readable 429 payload aimed at API clients."""
    extra_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    error_body = {
        "code": "RATE_LIMIT_EXCEEDED",
        "message": exc.message,
        "retry_after_seconds": exc.retry_after,
        "documentation_url": "https://docs.example.com/rate-limits",
    }
    payload = {
        "error": error_body,
        "request_id": request.headers.get("X-Request-ID", "unknown"),
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return JSONResponse(status_code=429, content=payload, headers=extra_headers)
# 2. Callback for logging/monitoring when requests are blocked
async def log_blocked_request(request: Request, info: Any) -> None:
    """Emit a warning log entry for each request rejected by the limiter."""
    ip = request.client.host if request.client else "unknown"
    remaining = info.remaining if info else "unknown"
    logger.warning(
        "Rate limit exceeded: ip=%s path=%s user_agent=%s remaining=%s",
        ip,
        request.url.path,
        request.headers.get("User-Agent", "unknown"),
        remaining,
    )
    # Production ideas:
    #   - push a metric (Prometheus, DataDog, ...)
    #   - alert on suspicious patterns
    #   - feed repeat offenders into a blocklist
@app.get("/api/monitored")
@rate_limit(
limit=5,
window_size=60,
on_blocked=log_blocked_request,
)
async def monitored_endpoint(request: Request) -> dict[str, str]:
"""Endpoint with blocked request logging."""
return {"message": "Success"}
# 3. Custom error messages per endpoint
@app.get("/api/search")
@rate_limit(
    limit=10,
    window_size=60,
    error_message="Search rate limit exceeded. Please wait before searching again.",
)
async def search_endpoint(request: Request, q: str = "") -> dict[str, Any]:
    """Search route whose 429 response carries a search-specific message."""
    matches: list[Any] = []
    return {"query": q, "results": matches}
@app.get("/api/upload")
@rate_limit(
limit=5,
window_size=300, # 5 uploads per 5 minutes
error_message="Upload limit reached. You can upload 5 files every 5 minutes.",
)
async def upload_endpoint(request: Request) -> dict[str, str]:
"""Upload with custom error message."""
return {"message": "Upload successful"}
# 4. Different response formats based on Accept header
@app.get("/api/flexible")
@rate_limit(limit=10, window_size=60)
async def flexible_endpoint(request: Request) -> dict[str, str]:
    """Plain route; its 429 format is negotiated by the Accept-aware handler."""
    return {"message": "Success", "data": "Some data"}
# Custom exception handler that respects Accept header
# NOTE(review): this second @app.exception_handler(RateLimitExceeded)
# registration replaces json_rate_limit_handler above — FastAPI keeps a
# single handler per exception type. Kept as-is since this example file
# deliberately shows both styles; confirm which handler should win.
@app.exception_handler(RateLimitExceeded)
async def flexible_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Return response in format matching Accept header."""
    # Default to JSON when the client states no preference.
    accept = request.headers.get("Accept", "application/json")
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    if "text/html" in accept:
        # Human-friendly page for browser traffic.
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head><title>Rate Limit Exceeded</title></head>
        <body>
        <h1>429 - Too Many Requests</h1>
        <p>{exc.message}</p>
        <p>Please try again in {exc.retry_after:.0f} seconds.</p>
        </body>
        </html>
        """
        return HTMLResponse(content=html_content, status_code=429, headers=headers)
    elif "text/plain" in accept:
        # Plain text for CLI tools / curl without an Accept header override.
        return PlainTextResponse(
            content=f"Rate limit exceeded. Retry after {exc.retry_after:.0f} seconds.",
            status_code=429,
            headers=headers,
        )
    else:
        # JSON fallback for everything else.
        return JSONResponse(
            status_code=429,
            content={
                "error": "rate_limit_exceeded",
                "message": exc.message,
                "retry_after": exc.retry_after,
            },
            headers=headers,
        )
# 5. Include helpful information in response headers
@app.get("/api/verbose-headers")
@rate_limit(
    limit=10,
    window_size=60,
    include_headers=True,  # emit X-RateLimit-* headers on every response
)
async def verbose_headers_endpoint(request: Request) -> dict[str, Any]:
    """Successful responses advertise quota state via X-RateLimit-* headers."""
    advertised = [
        "X-RateLimit-Limit",
        "X-RateLimit-Remaining",
        "X-RateLimit-Reset",
    ]
    return {
        "message": "Check response headers for rate limit info",
        "headers_included": advertised,
    }
# 6. Graceful degradation - return cached/stale data instead of error
cached_data = {"data": "Cached response", "cached_at": datetime.now(timezone.utc).isoformat()}


async def return_cached_on_limit(request: Request, info: Any) -> None:
    """Log that stale data would be served; cannot itself cancel the exception."""
    logger.info("Returning cached data due to rate limit")
    # The on_blocked hook fires when a request is rejected, but the
    # RateLimitExceeded exception is still raised afterwards; actually
    # serving cached_data would require custom middleware.
@app.get("/api/graceful")
@rate_limit(
limit=5,
window_size=60,
on_blocked=return_cached_on_limit,
)
async def graceful_endpoint(request: Request) -> dict[str, str]:
"""Endpoint with graceful degradation."""
return {"message": "Fresh data", "timestamp": datetime.now(timezone.utc).isoformat()}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,325 @@
"""Advanced patterns and real-world use cases for rate limiting."""
from __future__ import annotations
import hashlib
import time
from contextlib import asynccontextmanager
from typing import Any
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimitExceeded,
RateLimiter,
rate_limit,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
# In-memory storage is fine for a single-process demo; a distributed
# deployment would swap in a shared backend instead.
backend = MemoryBackend()
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the limiter before serving and release it on shutdown."""
    await limiter.initialize()
    set_limiter(limiter)  # register as the process-wide limiter for @rate_limit
    yield
    await limiter.close()
app = FastAPI(title="Advanced Patterns", lifespan=lifespan)


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate RateLimitExceeded into a JSON 429 with rate-limit headers."""
    response_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    body = {"error": "rate_limit_exceeded", "retry_after": exc.retry_after}
    return JSONResponse(status_code=429, content=body, headers=response_headers)
# =============================================================================
# Pattern 1: Cost-based rate limiting
# Different operations consume different amounts of quota
# =============================================================================
@app.get("/api/list")
@rate_limit(limit=100, window_size=60, cost=1)
async def list_items(request: Request) -> dict[str, Any]:
"""Cheap operation - costs 1 token."""
return {"items": ["a", "b", "c"], "cost": 1}
@app.get("/api/details/{item_id}")
@rate_limit(limit=100, window_size=60, cost=5)
async def get_details(request: Request, item_id: str) -> dict[str, Any]:
"""Medium operation - costs 5 tokens."""
return {"item_id": item_id, "details": "...", "cost": 5}
@app.post("/api/generate")
@rate_limit(limit=100, window_size=60, cost=20)
async def generate_content(request: Request) -> dict[str, Any]:
"""Expensive operation - costs 20 tokens."""
return {"generated": "AI-generated content...", "cost": 20}
@app.post("/api/bulk-export")
@rate_limit(limit=100, window_size=60, cost=50)
async def bulk_export(request: Request) -> dict[str, Any]:
"""Very expensive operation - costs 50 tokens."""
return {"export_url": "https://...", "cost": 50}
# =============================================================================
# Pattern 2: Sliding scale exemptions
# Gradually reduce limits instead of hard blocking
# =============================================================================
def get_request_priority(request: Request) -> int:
    """Score a request's importance: premium 100, authenticated 50, anonymous 10."""
    is_premium = request.headers.get("X-Premium-User") == "true"
    is_authenticated = bool(request.headers.get("Authorization"))
    if is_premium:
        return 100
    if is_authenticated:
        return 50
    return 10
def should_exempt_high_priority(request: Request) -> bool:
    """True when the request scores high enough (>= 100) to skip rate limiting."""
    priority_score = get_request_priority(request)
    return priority_score >= 100
@app.get("/api/priority-based")
@rate_limit(
limit=10,
window_size=60,
exempt_when=should_exempt_high_priority,
)
async def priority_endpoint(request: Request) -> dict[str, Any]:
"""Premium users are exempt from rate limits."""
priority = get_request_priority(request)
return {
"message": "Success",
"priority": priority,
"exempt": priority >= 100,
}
# =============================================================================
# Pattern 3: Rate limit by resource, not just user
# Prevent abuse of specific resources
# =============================================================================
def resource_key_extractor(request: Request) -> str:
    """Build a bucket key scoped to one (resource, user) pair."""
    rid = request.path_params.get("resource_id", "unknown")
    uid = request.headers.get("X-User-ID", "anonymous")
    return f"resource:{rid}:user:{uid}"
@app.get("/api/resources/{resource_id}")
@rate_limit(
limit=10,
window_size=60,
key_extractor=resource_key_extractor,
)
async def get_resource(request: Request, resource_id: str) -> dict[str, str]:
"""Each user can access each resource 10 times per minute."""
return {"resource_id": resource_id, "data": "..."}
# =============================================================================
# Pattern 4: Login/authentication rate limiting
# Prevent brute force attacks
# =============================================================================
def login_key_extractor(request: Request) -> str:
    """Bucket login attempts by client IP plus attempted username."""
    client_ip = request.client.host if request.client else "unknown"
    # A real application would read the username out of the request body.
    attempted_user = request.headers.get("X-Username", "unknown")
    return f"login:{client_ip}:{attempted_user}"
@app.post("/auth/login")
@rate_limit(
limit=5,
window_size=300, # 5 attempts per 5 minutes
algorithm=Algorithm.SLIDING_WINDOW, # Precise tracking for security
key_extractor=login_key_extractor,
error_message="Too many login attempts. Please try again in 5 minutes.",
)
async def login(request: Request) -> dict[str, str]:
"""Login endpoint with brute force protection."""
return {"message": "Login successful", "token": "..."}
# Password reset - even stricter limits
def password_reset_key(request: Request) -> str:
    """Bucket password-reset requests per client IP."""
    source_ip = request.client.host if request.client else "unknown"
    return f"password_reset:{source_ip}"
@app.post("/auth/password-reset")
@rate_limit(
limit=3,
window_size=3600, # 3 attempts per hour
key_extractor=password_reset_key,
error_message="Too many password reset requests. Please try again later.",
)
async def password_reset(request: Request) -> dict[str, str]:
"""Password reset with strict rate limiting."""
return {"message": "Password reset email sent"}
# =============================================================================
# Pattern 5: Webhook/callback rate limiting
# Limit outgoing requests to prevent overwhelming external services
# =============================================================================
# Shared dependency: at most 100 webhook sends per minute per caller,
# tracked under its own "webhook" key namespace.
webhook_rate_limit = RateLimitDependency(
    limit=100,
    window_size=60,
    key_prefix="webhook",
)
async def check_webhook_limit(
    request: Request,
    webhook_url: str,
) -> None:
    """Derive a per-destination-domain key for webhook rate limiting.

    Simplified placeholder: a production version would hand the key to the
    limiter directly instead of discarding it.
    """
    from urllib.parse import urlparse

    destination = urlparse(webhook_url).netloc
    _key = f"webhook:{destination}"  # would be passed to the limiter
    _ = _key  # keep linters quiet about the unused placeholder
@app.post("/api/send-webhook")
async def send_webhook(
request: Request,
webhook_url: str = "https://example.com/webhook",
rate_info: Any = Depends(webhook_rate_limit),
) -> dict[str, Any]:
"""Send webhook with rate limiting to protect external services."""
# await check_webhook_limit(request, webhook_url)
return {
"message": "Webhook sent",
"destination": webhook_url,
"remaining_quota": rate_info.remaining,
}
# =============================================================================
# Pattern 6: Request fingerprinting
# Detect and limit similar requests (e.g., spam prevention)
# =============================================================================
def request_fingerprint(request: Request) -> str:
    """Derive a stable 16-hex-char fingerprint key for the requesting client.

    Combines client IP, User-Agent and Accept-Language so repeated
    submissions from the same browser/device share one rate-limit bucket.
    """
    ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "")
    accept_language = request.headers.get("Accept-Language", "")
    fingerprint_data = f"{ip}:{user_agent}:{accept_language}"
    # sha256 instead of md5: same truncated 16-char usage, but avoids
    # security-linter findings (md5 is flagged even for non-crypto uses).
    digest = hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16]
    return f"fingerprint:{digest}"
@app.post("/api/submit-form")
@rate_limit(
limit=5,
window_size=60,
key_extractor=request_fingerprint,
error_message="Too many submissions from this device.",
)
async def submit_form(request: Request) -> dict[str, str]:
"""Form submission with fingerprint-based rate limiting."""
return {"message": "Form submitted successfully"}
# =============================================================================
# Pattern 7: Time-of-day based limits
# Different limits during peak vs off-peak hours
# =============================================================================
def is_peak_hours() -> bool:
    """Report whether the current UTC hour is in the 09:00-17:59 peak band."""
    hour_utc = time.gmtime().tm_hour
    return 9 <= hour_utc < 18
def peak_aware_exempt(request: Request) -> bool:
    """Exempt traffic from rate limiting whenever peak hours are not active."""
    currently_peak = is_peak_hours()
    return not currently_peak
@app.get("/api/peak-aware")
@rate_limit(
limit=10, # Strict limit during peak hours
window_size=60,
exempt_when=peak_aware_exempt, # No limit during off-peak
)
async def peak_aware_endpoint(request: Request) -> dict[str, Any]:
"""Stricter limits during peak hours."""
return {
"message": "Success",
"is_peak_hours": is_peak_hours(),
"rate_limited": is_peak_hours(),
}
# =============================================================================
# Pattern 8: Cascading limits (multiple tiers)
# =============================================================================
# Three stacked dependencies, one per time horizon; distinct key_prefix
# values keep their counters independent in the backend.
per_second = RateLimitDependency(limit=5, window_size=1, key_prefix="sec")
per_minute = RateLimitDependency(limit=100, window_size=60, key_prefix="min")
per_hour = RateLimitDependency(limit=1000, window_size=3600, key_prefix="hour")
async def cascading_limits(
    request: Request,
    sec_info: Any = Depends(per_second),
    min_info: Any = Depends(per_minute),
    hour_info: Any = Depends(per_hour),
) -> dict[str, Any]:
    """Evaluate all three tiers; any single one can reject the request."""
    tiers = {
        "per_second": sec_info,
        "per_minute": min_info,
        "per_hour": hour_info,
    }
    return {name: {"remaining": info.remaining} for name, info in tiers.items()}
@app.get("/api/cascading")
async def cascading_endpoint(
request: Request,
limits: dict[str, Any] = Depends(cascading_limits),
) -> dict[str, Any]:
"""Endpoint with per-second, per-minute, and per-hour limits."""
return {"message": "Success", "limits": limits}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)

133
examples/README.md Normal file
View File

@@ -0,0 +1,133 @@
# FastAPI Traffic Examples
This directory contains comprehensive examples demonstrating how to use the `fastapi-traffic` rate limiting library.
## Basic Examples
### 01_quickstart.py
Minimal setup to get rate limiting working. Start here if you're new to the library.
- Basic backend and limiter setup
- Exception handler for rate limit errors
- Simple decorator usage
### 02_algorithms.py
Demonstrates all available rate limiting algorithms:
- **Fixed Window** - Simple, resets at fixed intervals
- **Sliding Window** - Most precise, stores timestamps
- **Sliding Window Counter** - Balance of precision and efficiency (default)
- **Token Bucket** - Allows controlled bursts
- **Leaky Bucket** - Smooths out traffic
### 03_backends.py
Shows different storage backends:
- **MemoryBackend** - Fast, ephemeral (default)
- **SQLiteBackend** - Persistent, single-instance
- **RedisBackend** - Distributed, multi-instance
### 04_key_extractors.py
Custom key extractors for different rate limiting strategies:
- Rate limit by IP address (default)
- Rate limit by API key
- Rate limit by user ID
- Rate limit by endpoint + IP
- Rate limit by tenant/organization
- Composite keys (user + action)
### 05_middleware.py
Middleware-based rate limiting for global protection:
- Basic middleware setup
- Custom configuration options
- Path and IP exemptions
- Alternative middleware classes
## Advanced Examples
### 06_dependency_injection.py
Using FastAPI's dependency injection system:
- Basic rate limit dependency
- Tier-based rate limiting
- Combining multiple rate limits
- Conditional exemptions
### 07_redis_distributed.py
Redis backend for distributed deployments:
- Multi-instance rate limiting
- Shared counters across nodes
- Health checks and statistics
- Fallback to memory backend
### 08_tiered_api.py
Production-ready tiered API example:
- Free, Starter, Pro, Enterprise tiers
- Different limits per tier
- Feature gating based on tier
- API key validation
### 09_custom_responses.py
Customizing rate limit responses:
- Custom JSON error responses
- Logging/monitoring callbacks
- Different response formats (JSON, HTML, plain text)
- Rate limit headers
### 10_advanced_patterns.py
Real-world patterns and use cases:
- **Cost-based limiting** - Different operations cost different amounts
- **Priority exemptions** - Premium users exempt from limits
- **Resource-based limiting** - Limit by resource ID + user
- **Login protection** - Brute force prevention
- **Webhook limiting** - Protect external services
- **Request fingerprinting** - Spam prevention
- **Time-of-day limits** - Peak vs off-peak hours
- **Cascading limits** - Per-second, per-minute, per-hour
## Running Examples
Each example is a standalone FastAPI application. Run with:
```bash
# Using uvicorn directly
uvicorn examples.01_quickstart:app --reload
# Or run the file directly
python examples/01_quickstart.py
```
## Testing Rate Limits
Use curl or httpie to test:
```bash
# Basic request
curl http://localhost:8000/api/basic
# With API key
curl -H "X-API-Key: my-key" http://localhost:8000/api/by-api-key
# Check rate limit headers
curl -i http://localhost:8000/api/data
# Rapid requests to trigger rate limit
for i in {1..20}; do curl http://localhost:8000/api/basic; done
```
## Environment Variables
Some examples support configuration via environment variables:
- `RATE_LIMIT_BACKEND` - Backend type (memory, sqlite, redis)
- `REDIS_URL` - Redis connection URL for distributed examples
## Requirements
Basic examples only need `fastapi-traffic` and `uvicorn`:
```bash
pip install fastapi-traffic uvicorn
```
For Redis examples:
```bash
pip install redis
```

171
examples/basic_usage.py Normal file
View File

@@ -0,0 +1,171 @@
"""Basic usage examples for fastapi-traffic."""
from __future__ import annotations
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
RateLimitExceeded,
RateLimiter,
SQLiteBackend,
rate_limit,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
# Configure global rate limiter with SQLite backend for persistence
backend = SQLiteBackend("rate_limits.db")
limiter = RateLimiter(backend)
# NOTE(review): set_limiter runs at import time here, whereas the other
# examples register the limiter inside lifespan after initialize() —
# confirm early registration is intended.
set_limiter(limiter)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manage application lifespan - startup and shutdown."""
    # Startup: open backend connections before the app serves requests
    await limiter.initialize()
    yield
    # Shutdown: release backend resources cleanly
    await limiter.close()
app = FastAPI(title="FastAPI Traffic Example", lifespan=lifespan)
# Exception handler for rate limit exceeded
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn RateLimitExceeded into a 429 JSON response with limit headers."""
    extra = exc.limit_info.to_headers() if exc.limit_info else {}
    body = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(status_code=429, content=body, headers=extra)
# Example 1: Basic decorator usage
@app.get("/api/basic")
@rate_limit(100, 60)  # 100 requests per minute
async def basic_endpoint(request: Request) -> dict[str, str]:
    """Simplest possible rate-limited route."""
    greeting = {"message": "Hello, World!"}
    return greeting
# Example 2: Custom algorithm
@app.get("/api/token-bucket")
@rate_limit(
    limit=50,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=10,  # short bursts of up to 10 requests are tolerated
)
async def token_bucket_endpoint(request: Request) -> dict[str, str]:
    """Route limited with the token-bucket algorithm."""
    return {"message": "Token bucket rate limiting"}
# Example 3: Sliding window for precise rate limiting
@app.get("/api/sliding-window")
@rate_limit(
    limit=30,
    window_size=60,
    algorithm=Algorithm.SLIDING_WINDOW,
)
async def sliding_window_endpoint(request: Request) -> dict[str, str]:
    """Route limited with the exact sliding-window algorithm."""
    return {"message": "Sliding window rate limiting"}
# Example 4: Custom key extractor (rate limit by API key)
def api_key_extractor(request: Request) -> str:
    """Derive the rate-limit bucket key from the X-API-Key header."""
    key = request.headers.get("X-API-Key", "anonymous")
    return f"api_key:{key}"
@app.get("/api/by-api-key")
@rate_limit(
limit=1000,
window_size=3600, # 1000 requests per hour
key_extractor=api_key_extractor,
)
async def api_key_endpoint(request: Request) -> dict[str, str]:
"""Endpoint rate limited by API key."""
return {"message": "Rate limited by API key"}
# Example 5: Using dependency injection
# Shared dependency: 20 requests per minute for every route that uses it.
rate_limit_dep = RateLimitDependency(limit=20, window_size=60)
@app.get("/api/dependency")
async def dependency_endpoint(
    request: Request,
    # NOTE(review): other examples read .remaining off this value, which
    # suggests an info object rather than a plain dict — confirm annotation.
    rate_info: dict[str, object] = Depends(rate_limit_dep),
) -> dict[str, object]:
    """Endpoint using rate limit as dependency."""
    return {
        "message": "Rate limit info available",
        "rate_limit": rate_info,
    }
# Example 6: Exempt certain requests
def is_admin(request: Request) -> bool:
    """Return True when the admin-token header matches the shared secret."""
    return request.headers.get("X-Admin-Token") == "secret-admin-token"
@app.get("/api/admin-exempt")
@rate_limit(
limit=10,
window_size=60,
exempt_when=is_admin,
)
async def admin_exempt_endpoint(request: Request) -> dict[str, str]:
"""Endpoint with admin exemption."""
return {"message": "Admins are exempt from rate limiting"}
# Example 7: Different costs for different operations
@app.post("/api/expensive")
@rate_limit(
    limit=100,
    window_size=60,
    cost=10,  # each call consumes 10 units of the shared budget
)
async def expensive_endpoint(request: Request) -> dict[str, str]:
    """Heavyweight operation that drains the quota faster."""
    return {"message": "Expensive operation completed"}
# Example 8: Global middleware rate limiting
# Uncomment to enable global rate limiting
# app.add_middleware(
# RateLimitMiddleware,
# limit=1000,
# window_size=60,
# exempt_paths={"/health", "/docs", "/openapi.json"},
# )
@app.get("/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint (typically exempt from rate limiting)."""
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)