- Refactor Redis backend connection handling and pool management - Update algorithm implementations with improved type annotations - Enhance config loader validation with stricter Pydantic schemas - Improve decorator and middleware error handling - Expand example scripts with better docstrings and usage patterns - Add new 00_basic_usage.py example for quick start - Reorganize examples directory structure - Fix type annotation inconsistencies across core modules - Update dependencies in pyproject.toml
130 lines
3.3 KiB
Python
130 lines
3.3 KiB
Python
"""Examples demonstrating middleware-based rate limiting."""
|
|
|
|
from __future__ import annotations

import hashlib

from fastapi import FastAPI, Request

from fastapi_traffic import MemoryBackend
from fastapi_traffic.middleware import RateLimitMiddleware
|
|
|
|
# Alternative middleware options (uncomment to use):
|
|
# from fastapi_traffic.middleware import SlidingWindowMiddleware
|
|
# from fastapi_traffic.middleware import TokenBucketMiddleware
|
|
|
|
|
|
# Fallback bind address/port for the ``__main__`` launcher (overridable via
# the --host / --port command-line flags).
DEFAULT_HOST: str = "127.0.0.1"
DEFAULT_PORT: int = 8001

# The application the rate-limiting middleware is attached to below.
app: FastAPI = FastAPI(title="Middleware Rate Limiting")
|
|
|
|
|
|
# Custom key extractor for middleware
def get_client_identifier(request: Request) -> str:
    """Build the rate-limit bucket key for *request*.

    A request carrying a non-empty ``X-API-Key`` header is bucketed per API
    key. Otherwise it is bucketed per client IP. Requests with neither share
    a single ``"unknown"`` bucket.

    The API key is reduced to a SHA-256 hex digest instead of being used
    verbatim, for two reasons:

    * Credentials are never written into the rate-limit backend's keys.
    * Every key has a fixed length, so a client cannot inflate backend
      storage by sending an arbitrarily long header value.

    NOTE(review): nothing here validates the header. A client can send a
    different random ``X-API-Key`` on each request, get a fresh bucket every
    time, and so sidestep the per-IP limit. Check the key against a real key
    store before trusting it in production.

    Args:
        request: The incoming request.

    Returns:
        ``"api_key:<sha256 hex digest>"``, ``"ip:<client host>"`` or
        ``"unknown"``.
    """
    # Check for API key first; an empty header value falls through to IP.
    api_key = request.headers.get("X-API-Key")
    if api_key:
        digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
        return f"api_key:{digest}"

    # Fall back to the peer IP. The guard exists because ``request.client``
    # may be None for some transports / test clients.
    if request.client:
        return f"ip:{request.client.host}"

    return "unknown"
|
|
|
|
|
|
# Option 1: Basic middleware with defaults
# Uncomment to use:
# app.add_middleware(
#     RateLimitMiddleware,
#     limit=100,
#     window_size=60,
# )

# Option 2: Middleware with custom configuration (this is the active one).
#
# NOTE(review): DEFAULT_HOST binds the demo server to 127.0.0.1, and that
# same address is listed in ``exempt_ips`` below. Requests made from the
# local machine are therefore presumably never throttled. Drop the entry
# (or reach the server from another host) to observe 429 responses.
app.add_middleware(
    RateLimitMiddleware,
    # Quota: ``limit`` requests per ``window_size`` window (units presumably
    # seconds -- confirm against the library docs).
    limit=100,
    window_size=60,
    # Where counters live and how each caller's bucket is named.
    backend=MemoryBackend(),
    key_prefix="global",
    key_extractor=get_client_identifier,
    # Shape of the response when a caller is over the limit.
    include_headers=True,
    status_code=429,
    error_message="You have exceeded the rate limit. Please slow down.",
    # Fail open: don't block requests if the backend fails.
    skip_on_error=True,
    # Traffic that bypasses the limiter entirely.
    exempt_paths={"/health", "/docs", "/openapi.json", "/redoc"},
    exempt_ips={"127.0.0.1"},  # Exempt localhost
)
|
|
|
|
|
|
# Option 3: Convenience middleware for specific algorithms
|
|
# SlidingWindowMiddleware - precise rate limiting
|
|
# app.add_middleware(
|
|
# SlidingWindowMiddleware,
|
|
# limit=100,
|
|
# window_size=60,
|
|
# )
|
|
|
|
# TokenBucketMiddleware - allows bursts
|
|
# app.add_middleware(
|
|
# TokenBucketMiddleware,
|
|
# limit=100,
|
|
# window_size=60,
|
|
# )
|
|
|
|
|
|
@app.get("/")
async def root() -> dict[str, str]:
    """Return a greeting; "/" is not exempt, so the middleware counts it."""
    greeting = "Hello, World!"
    return {"message": greeting}
|
|
|
|
|
|
@app.get("/api/data")
async def get_data() -> dict[str, str]:
    """Return a sample data payload (counted against the middleware quota)."""
    return dict(data="Some important data")
|
|
|
|
|
|
@app.get("/api/users")
async def get_users() -> dict[str, list[str]]:
    """List demo user names (counted against the middleware quota)."""
    names = ["alice", "bob", "charlie"]
    return {"users": names}
|
|
|
|
|
|
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe; "/health" is listed in the middleware's exempt paths."""
    return dict(status="healthy")
|
|
|
|
|
|
@app.get("/docs-info")
async def docs_info() -> dict[str, str]:
    """Point callers at the interactive API documentation pages."""
    hint = "Visit /docs for Swagger UI or /redoc for ReDoc"
    note = "These endpoints are exempt from rate limiting"
    return {"message": hint, "note": note}
|
|
|
|
|
|
if __name__ == "__main__":
    # Launcher-only dependencies are imported lazily so that importing this
    # module (e.g. from an ASGI server) doesn't require them.
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser(description="Middleware rate limiting example")
    cli.add_argument(
        "--host",
        default=DEFAULT_HOST,
        help=f"Host to bind to (default: {DEFAULT_HOST})",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"Port to bind to (default: {DEFAULT_PORT})",
    )
    options = cli.parse_args()

    # Serve the app (with its rate-limiting middleware) on the chosen address.
    uvicorn.run(app, host=options.host, port=options.port)
|