style: clean up unused parameters and imports in examples

This commit is contained in:
2026-02-04 01:08:16 +00:00
parent 6bc108078f
commit d7966f7e96
11 changed files with 143 additions and 95 deletions

View File

@@ -9,8 +9,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -21,7 +21,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
"""Lifespan context manager for startup/shutdown.""" """Lifespan context manager for startup/shutdown."""
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
@@ -34,7 +34,7 @@ app = FastAPI(title="Quickstart Example", lifespan=lifespan)
# Step 2: Add exception handler for rate limit errors # Step 2: Add exception handler for rate limit errors
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={"error": "Too many requests", "retry_after": exc.retry_after}, content={"error": "Too many requests", "retry_after": exc.retry_after},
@@ -44,17 +44,17 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
# Step 3: Apply rate limiting to endpoints # Step 3: Apply rate limiting to endpoints
@app.get("/") @app.get("/")
@rate_limit(10, 60) # 10 requests per minute @rate_limit(10, 60) # 10 requests per minute
async def hello(request: Request) -> dict[str, str]: async def hello(_: Request) -> dict[str, str]:
return {"message": "Hello, World!"} return {"message": "Hello, World!"}
@app.get("/api/data") @app.get("/api/data")
@rate_limit(100, 60) # 100 requests per minute @rate_limit(100, 60) # 100 requests per minute
async def get_data(request: Request) -> dict[str, str]: async def get_data(_: Request) -> dict[str, str]:
return {"data": "Some important data"} return {"data": "Some important data"}
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="127.0.0.1", port=8002)

View File

@@ -10,8 +10,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
Algorithm, Algorithm,
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -21,7 +21,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -32,7 +32,7 @@ app = FastAPI(title="Rate Limiting Algorithms", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
@@ -53,9 +53,12 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
window_size=60, window_size=60,
algorithm=Algorithm.FIXED_WINDOW, algorithm=Algorithm.FIXED_WINDOW,
) )
async def fixed_window(request: Request) -> dict[str, str]: async def fixed_window(_: Request) -> dict[str, str]:
"""Fixed window resets counter at fixed time intervals.""" """Fixed window resets counter at fixed time intervals."""
return {"algorithm": "fixed_window", "description": "Counter resets every 60 seconds"} return {
"algorithm": "fixed_window",
"description": "Counter resets every 60 seconds",
}
# 2. Sliding Window Log - Most precise # 2. Sliding Window Log - Most precise
@@ -67,9 +70,12 @@ async def fixed_window(request: Request) -> dict[str, str]:
window_size=60, window_size=60,
algorithm=Algorithm.SLIDING_WINDOW, algorithm=Algorithm.SLIDING_WINDOW,
) )
async def sliding_window(request: Request) -> dict[str, str]: async def sliding_window(_: Request) -> dict[str, str]:
"""Sliding window tracks exact timestamps for precise limiting.""" """Sliding window tracks exact timestamps for precise limiting."""
return {"algorithm": "sliding_window", "description": "Precise tracking with timestamp log"} return {
"algorithm": "sliding_window",
"description": "Precise tracking with timestamp log",
}
# 3. Sliding Window Counter - Balance of precision and efficiency # 3. Sliding Window Counter - Balance of precision and efficiency
@@ -81,9 +87,12 @@ async def sliding_window(request: Request) -> dict[str, str]:
window_size=60, window_size=60,
algorithm=Algorithm.SLIDING_WINDOW_COUNTER, algorithm=Algorithm.SLIDING_WINDOW_COUNTER,
) )
async def sliding_window_counter(request: Request) -> dict[str, str]: async def sliding_window_counter(_: Request) -> dict[str, str]:
"""Sliding window counter uses weighted counts from current and previous windows.""" """Sliding window counter uses weighted counts from current and previous windows."""
return {"algorithm": "sliding_window_counter", "description": "Efficient approximation"} return {
"algorithm": "sliding_window_counter",
"description": "Efficient approximation",
}
# 4. Token Bucket - Allows controlled bursts # 4. Token Bucket - Allows controlled bursts
@@ -96,7 +105,7 @@ async def sliding_window_counter(request: Request) -> dict[str, str]:
algorithm=Algorithm.TOKEN_BUCKET, algorithm=Algorithm.TOKEN_BUCKET,
burst_size=5, # Allow bursts of up to 5 requests burst_size=5, # Allow bursts of up to 5 requests
) )
async def token_bucket(request: Request) -> dict[str, str]: async def token_bucket(_: Request) -> dict[str, str]:
"""Token bucket allows bursts up to burst_size, then refills gradually.""" """Token bucket allows bursts up to burst_size, then refills gradually."""
return {"algorithm": "token_bucket", "description": "Allows controlled bursts"} return {"algorithm": "token_bucket", "description": "Allows controlled bursts"}
@@ -111,7 +120,7 @@ async def token_bucket(request: Request) -> dict[str, str]:
algorithm=Algorithm.LEAKY_BUCKET, algorithm=Algorithm.LEAKY_BUCKET,
burst_size=5, # Queue capacity burst_size=5, # Queue capacity
) )
async def leaky_bucket(request: Request) -> dict[str, str]: async def leaky_bucket(_: Request) -> dict[str, str]:
"""Leaky bucket smooths traffic to a constant rate.""" """Leaky bucket smooths traffic to a constant rate."""
return {"algorithm": "leaky_bucket", "description": "Constant output rate"} return {"algorithm": "leaky_bucket", "description": "Constant output rate"}

View File

@@ -11,8 +11,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
SQLiteBackend, SQLiteBackend,
rate_limit, rate_limit,
) )
@@ -32,9 +32,10 @@ def get_backend():
# Redis - Required for distributed/multi-instance deployments # Redis - Required for distributed/multi-instance deployments
# Requires: pip install redis # Requires: pip install redis
try: try:
from fastapi_traffic import RedisBackend
import asyncio import asyncio
from fastapi_traffic import RedisBackend
async def create_redis(): async def create_redis():
return await RedisBackend.from_url( return await RedisBackend.from_url(
os.getenv("REDIS_URL", "redis://localhost:6379/0"), os.getenv("REDIS_URL", "redis://localhost:6379/0"),
@@ -56,7 +57,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -67,7 +68,7 @@ app = FastAPI(title="Storage Backends Example", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
@@ -76,7 +77,7 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
@app.get("/api/resource") @app.get("/api/resource")
@rate_limit(100, 60) @rate_limit(100, 60)
async def get_resource(request: Request) -> dict[str, str]: async def get_resource(_: Request) -> dict[str, str]:
return {"message": "Resource data", "backend": type(backend).__name__} return {"message": "Resource data", "backend": type(backend).__name__}

View File

@@ -9,8 +9,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -20,7 +20,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -31,7 +31,7 @@ app = FastAPI(title="Custom Key Extractors", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
@@ -43,7 +43,10 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
@rate_limit(10, 60) # Uses default IP-based key extractor @rate_limit(10, 60) # Uses default IP-based key extractor
async def by_ip(request: Request) -> dict[str, str]: async def by_ip(request: Request) -> dict[str, str]:
"""Rate limited by client IP address (default behavior).""" """Rate limited by client IP address (default behavior)."""
return {"limited_by": "ip", "client_ip": request.client.host if request.client else "unknown"} return {
"limited_by": "ip",
"client_ip": request.client.host if request.client else "unknown",
}
# 2. Rate limit by API key # 2. Rate limit by API key
@@ -99,7 +102,7 @@ def endpoint_ip_extractor(request: Request) -> str:
window_size=60, window_size=60,
key_extractor=endpoint_ip_extractor, key_extractor=endpoint_ip_extractor,
) )
async def endpoint_specific(request: Request) -> dict[str, str]: async def endpoint_specific(_: Request) -> dict[str, str]:
"""Each endpoint has its own rate limit counter.""" """Each endpoint has its own rate limit counter."""
return {"limited_by": "endpoint+ip"} return {"limited_by": "endpoint+ip"}

View File

@@ -10,8 +10,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
) )
from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -21,7 +21,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
"""Lifespan context manager for startup/shutdown.""" """Lifespan context manager for startup/shutdown."""
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
@@ -33,7 +33,7 @@ app = FastAPI(title="Dependency Injection Example", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
@@ -46,7 +46,7 @@ basic_rate_limit = RateLimitDependency(limit=10, window_size=60)
@app.get("/basic") @app.get("/basic")
async def basic_endpoint( async def basic_endpoint(
request: Request, _: Request,
rate_info: Any = Depends(basic_rate_limit), rate_info: Any = Depends(basic_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Access rate limit info in your endpoint logic.""" """Access rate limit info in your endpoint logic."""
@@ -131,7 +131,7 @@ api_rate_limit = RateLimitDependency(
@app.get("/api/resource") @app.get("/api/resource")
async def api_resource( async def api_resource(
request: Request, _: Request,
rate_info: Any = Depends(api_rate_limit), rate_info: Any = Depends(api_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""API endpoint with per-API-key rate limiting.""" """API endpoint with per-API-key rate limiting."""
@@ -156,7 +156,7 @@ per_hour_limit = RateLimitDependency(
async def combined_rate_limit( async def combined_rate_limit(
request: Request, _: Request,
minute_info: Any = Depends(per_minute_limit), minute_info: Any = Depends(per_minute_limit),
hour_info: Any = Depends(per_hour_limit), hour_info: Any = Depends(per_hour_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -175,7 +175,7 @@ async def combined_rate_limit(
@app.get("/combined") @app.get("/combined")
async def combined_endpoint( async def combined_endpoint(
request: Request, _: Request,
rate_info: dict[str, Any] = Depends(combined_rate_limit), rate_info: dict[str, Any] = Depends(combined_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Endpoint with multiple rate limit tiers.""" """Endpoint with multiple rate limit tiers."""
@@ -209,9 +209,13 @@ async def internal_exempt_endpoint(
return { return {
"message": "Success", "message": "Success",
"is_internal": is_internal, "is_internal": is_internal,
"rate_limit": None if is_internal else { "rate_limit": (
None
if is_internal
else {
"remaining": rate_info.remaining, "remaining": rate_info.remaining,
}, }
),
} }

View File

@@ -14,20 +14,20 @@ from __future__ import annotations
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Annotated
from fastapi import Depends, FastAPI, Request from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import Annotated
from fastapi_traffic import ( from fastapi_traffic import (
Algorithm, Algorithm,
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.backends.redis import RedisBackend from fastapi_traffic.backends.redis import RedisBackend
from fastapi_traffic.core.limiter import set_limiter
async def create_redis_backend(): async def create_redis_backend():
@@ -94,7 +94,7 @@ LimiterDep = Annotated[RateLimiter, Depends(get_limiter)]
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
@@ -113,7 +113,7 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
window_size=60, window_size=60,
key_prefix="shared", key_prefix="shared",
) )
async def shared_limit(request: Request) -> dict[str, str]: async def shared_limit(_: Request) -> dict[str, str]:
"""This rate limit is shared across all application instances.""" """This rate limit is shared across all application instances."""
return { return {
"message": "Success", "message": "Success",
@@ -152,7 +152,7 @@ async def user_limit(request: Request) -> dict[str, str]:
burst_size=20, burst_size=20,
key_prefix="burst", key_prefix="burst",
) )
async def burst_allowed(request: Request) -> dict[str, str]: async def burst_allowed(_: Request) -> dict[str, str]:
"""Token bucket with Redis allows controlled bursts across instances.""" """Token bucket with Redis allows controlled bursts across instances."""
return {"message": "Burst request successful"} return {"message": "Burst request successful"}

View File

@@ -13,8 +13,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
Algorithm, Algorithm,
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
) )
from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -24,7 +24,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -82,7 +82,14 @@ TIER_CONFIGS: dict[Tier, TierConfig] = {
requests_per_hour=50000, requests_per_hour=50000,
requests_per_day=500000, requests_per_day=500000,
burst_size=200, burst_size=200,
features=["basic_api", "webhooks", "analytics", "priority_support", "sla", "custom_integrations"], features=[
"basic_api",
"webhooks",
"analytics",
"priority_support",
"sla",
"custom_integrations",
],
), ),
} }
@@ -109,7 +116,9 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
"message": exc.message, "message": exc.message,
"retry_after": exc.retry_after, "retry_after": exc.retry_after,
"tier": tier.value, "tier": tier.value,
"upgrade_url": "https://example.com/pricing" if tier != Tier.ENTERPRISE else None, "upgrade_url": (
"https://example.com/pricing" if tier != Tier.ENTERPRISE else None
),
}, },
headers=exc.limit_info.to_headers() if exc.limit_info else {}, headers=exc.limit_info.to_headers() if exc.limit_info else {},
) )
@@ -171,7 +180,7 @@ async def apply_tier_rate_limit(
@app.get("/api/v1/data") @app.get("/api/v1/data")
async def get_data( async def get_data(
request: Request, _: Request,
limit_info: dict[str, Any] = Depends(apply_tier_rate_limit), limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get data with tier-based rate limiting.""" """Get data with tier-based rate limiting."""
@@ -188,7 +197,7 @@ async def get_data(
@app.get("/api/v1/analytics") @app.get("/api/v1/analytics")
async def get_analytics( async def get_analytics(
request: Request, _: Request,
limit_info: dict[str, Any] = Depends(apply_tier_rate_limit), limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Analytics endpoint - requires Pro tier or higher.""" """Analytics endpoint - requires Pro tier or higher."""
@@ -228,7 +237,11 @@ async def get_tier_info(
"burst_size": config.burst_size, "burst_size": config.burst_size,
}, },
"features": config.features, "features": config.features,
"upgrade_options": [t.value for t in Tier if TIER_CONFIGS[t].requests_per_minute > config.requests_per_minute], "upgrade_options": [
t.value
for t in Tier
if TIER_CONFIGS[t].requests_per_minute > config.requests_per_minute
],
} }

View File

@@ -12,8 +12,8 @@ from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi_traffic import ( from fastapi_traffic import (
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
@@ -26,7 +26,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -38,7 +38,9 @@ app = FastAPI(title="Custom Responses Example", lifespan=lifespan)
# 1. Standard JSON error response # 1. Standard JSON error response
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def json_rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def json_rate_limit_handler(
request: Request, exc: RateLimitExceeded
) -> JSONResponse:
"""Standard JSON response for API clients.""" """Standard JSON response for API clients."""
headers = exc.limit_info.to_headers() if exc.limit_info else {} headers = exc.limit_info.to_headers() if exc.limit_info else {}
@@ -85,7 +87,7 @@ async def log_blocked_request(request: Request, info: Any) -> None:
window_size=60, window_size=60,
on_blocked=log_blocked_request, on_blocked=log_blocked_request,
) )
async def monitored_endpoint(request: Request) -> dict[str, str]: async def monitored_endpoint(_: Request) -> dict[str, str]:
"""Endpoint with blocked request logging.""" """Endpoint with blocked request logging."""
return {"message": "Success"} return {"message": "Success"}
@@ -97,7 +99,7 @@ async def monitored_endpoint(request: Request) -> dict[str, str]:
window_size=60, window_size=60,
error_message="Search rate limit exceeded. Please wait before searching again.", error_message="Search rate limit exceeded. Please wait before searching again.",
) )
async def search_endpoint(request: Request, q: str = "") -> dict[str, Any]: async def search_endpoint(_: Request, q: str = "") -> dict[str, Any]:
"""Search with custom error message.""" """Search with custom error message."""
return {"query": q, "results": []} return {"query": q, "results": []}
@@ -108,7 +110,7 @@ async def search_endpoint(request: Request, q: str = "") -> dict[str, Any]:
window_size=300, # 5 uploads per 5 minutes window_size=300, # 5 uploads per 5 minutes
error_message="Upload limit reached. You can upload 5 files every 5 minutes.", error_message="Upload limit reached. You can upload 5 files every 5 minutes.",
) )
async def upload_endpoint(request: Request) -> dict[str, str]: async def upload_endpoint(_: Request) -> dict[str, str]:
"""Upload with custom error message.""" """Upload with custom error message."""
return {"message": "Upload successful"} return {"message": "Upload successful"}
@@ -116,7 +118,7 @@ async def upload_endpoint(request: Request) -> dict[str, str]:
# 4. Different response formats based on Accept header # 4. Different response formats based on Accept header
@app.get("/api/flexible") @app.get("/api/flexible")
@rate_limit(limit=10, window_size=60) @rate_limit(limit=10, window_size=60)
async def flexible_endpoint(request: Request) -> dict[str, str]: async def flexible_endpoint(_: Request) -> dict[str, str]:
"""Endpoint that returns different formats.""" """Endpoint that returns different formats."""
return {"message": "Success", "data": "Some data"} return {"message": "Success", "data": "Some data"}
@@ -168,7 +170,7 @@ async def flexible_rate_limit_handler(request: Request, exc: RateLimitExceeded):
window_size=60, window_size=60,
include_headers=True, # Includes X-RateLimit-* headers include_headers=True, # Includes X-RateLimit-* headers
) )
async def verbose_headers_endpoint(request: Request) -> dict[str, Any]: async def verbose_headers_endpoint(_: Request) -> dict[str, Any]:
"""Response includes detailed rate limit headers.""" """Response includes detailed rate limit headers."""
return { return {
"message": "Check response headers for rate limit info", "message": "Check response headers for rate limit info",
@@ -181,10 +183,13 @@ async def verbose_headers_endpoint(request: Request) -> dict[str, Any]:
# 6. Graceful degradation - return cached/stale data instead of error # 6. Graceful degradation - return cached/stale data instead of error
cached_data = {"data": "Cached response", "cached_at": datetime.now(timezone.utc).isoformat()} cached_data = {
"data": "Cached response",
"cached_at": datetime.now(timezone.utc).isoformat(),
}
async def return_cached_on_limit(request: Request, info: Any) -> None: async def return_cached_on_limit(_: Request, __: Any) -> None:
"""Log when rate limited (callback doesn't prevent exception).""" """Log when rate limited (callback doesn't prevent exception)."""
logger.info("Returning cached data due to rate limit") logger.info("Returning cached data due to rate limit")
# This callback is called when blocked, but doesn't prevent the exception # This callback is called when blocked, but doesn't prevent the exception
@@ -197,9 +202,12 @@ async def return_cached_on_limit(request: Request, info: Any) -> None:
window_size=60, window_size=60,
on_blocked=return_cached_on_limit, on_blocked=return_cached_on_limit,
) )
async def graceful_endpoint(request: Request) -> dict[str, str]: async def graceful_endpoint(_: Request) -> dict[str, str]:
"""Endpoint with graceful degradation.""" """Endpoint with graceful degradation."""
return {"message": "Fresh data", "timestamp": datetime.now(timezone.utc).isoformat()} return {
"message": "Fresh data",
"timestamp": datetime.now(timezone.utc).isoformat(),
}
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -13,8 +13,8 @@ from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
Algorithm, Algorithm,
MemoryBackend, MemoryBackend,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.decorator import RateLimitDependency
@@ -25,7 +25,7 @@ limiter = RateLimiter(backend)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
await limiter.initialize() await limiter.initialize()
set_limiter(limiter) set_limiter(limiter)
yield yield
@@ -36,7 +36,7 @@ app = FastAPI(title="Advanced Patterns", lifespan=lifespan)
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after}, content={"error": "rate_limit_exceeded", "retry_after": exc.retry_after},
@@ -49,30 +49,31 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
# Different operations consume different amounts of quota # Different operations consume different amounts of quota
# ============================================================================= # =============================================================================
@app.get("/api/list") @app.get("/api/list")
@rate_limit(limit=100, window_size=60, cost=1) @rate_limit(limit=100, window_size=60, cost=1)
async def list_items(request: Request) -> dict[str, Any]: async def list_items(_: Request) -> dict[str, Any]:
"""Cheap operation - costs 1 token.""" """Cheap operation - costs 1 token."""
return {"items": ["a", "b", "c"], "cost": 1} return {"items": ["a", "b", "c"], "cost": 1}
@app.get("/api/details/{item_id}") @app.get("/api/details/{item_id}")
@rate_limit(limit=100, window_size=60, cost=5) @rate_limit(limit=100, window_size=60, cost=5)
async def get_details(request: Request, item_id: str) -> dict[str, Any]: async def get_details(_: Request, item_id: str) -> dict[str, Any]:
"""Medium operation - costs 5 tokens.""" """Medium operation - costs 5 tokens."""
return {"item_id": item_id, "details": "...", "cost": 5} return {"item_id": item_id, "details": "...", "cost": 5}
@app.post("/api/generate") @app.post("/api/generate")
@rate_limit(limit=100, window_size=60, cost=20) @rate_limit(limit=100, window_size=60, cost=20)
async def generate_content(request: Request) -> dict[str, Any]: async def generate_content(_: Request) -> dict[str, Any]:
"""Expensive operation - costs 20 tokens.""" """Expensive operation - costs 20 tokens."""
return {"generated": "AI-generated content...", "cost": 20} return {"generated": "AI-generated content...", "cost": 20}
@app.post("/api/bulk-export") @app.post("/api/bulk-export")
@rate_limit(limit=100, window_size=60, cost=50) @rate_limit(limit=100, window_size=60, cost=50)
async def bulk_export(request: Request) -> dict[str, Any]: async def bulk_export(_: Request) -> dict[str, Any]:
"""Very expensive operation - costs 50 tokens.""" """Very expensive operation - costs 50 tokens."""
return {"export_url": "https://...", "cost": 50} return {"export_url": "https://...", "cost": 50}
@@ -82,6 +83,7 @@ async def bulk_export(request: Request) -> dict[str, Any]:
# Gradually reduce limits instead of hard blocking # Gradually reduce limits instead of hard blocking
# ============================================================================= # =============================================================================
def get_request_priority(request: Request) -> int: def get_request_priority(request: Request) -> int:
"""Determine request priority (higher = more important).""" """Determine request priority (higher = more important)."""
# Premium users get higher priority # Premium users get higher priority
@@ -122,6 +124,7 @@ async def priority_endpoint(request: Request) -> dict[str, Any]:
# Prevent abuse of specific resources # Prevent abuse of specific resources
# ============================================================================= # =============================================================================
def resource_key_extractor(request: Request) -> str: def resource_key_extractor(request: Request) -> str:
"""Rate limit by resource ID + user.""" """Rate limit by resource ID + user."""
resource_id = request.path_params.get("resource_id", "unknown") resource_id = request.path_params.get("resource_id", "unknown")
@@ -135,7 +138,7 @@ def resource_key_extractor(request: Request) -> str:
window_size=60, window_size=60,
key_extractor=resource_key_extractor, key_extractor=resource_key_extractor,
) )
async def get_resource(request: Request, resource_id: str) -> dict[str, str]: async def get_resource(_: Request, resource_id: str) -> dict[str, str]:
"""Each user can access each resource 10 times per minute.""" """Each user can access each resource 10 times per minute."""
return {"resource_id": resource_id, "data": "..."} return {"resource_id": resource_id, "data": "..."}
@@ -145,6 +148,7 @@ async def get_resource(request: Request, resource_id: str) -> dict[str, str]:
# Prevent brute force attacks # Prevent brute force attacks
# ============================================================================= # =============================================================================
def login_key_extractor(request: Request) -> str: def login_key_extractor(request: Request) -> str:
"""Rate limit by IP + username to prevent brute force.""" """Rate limit by IP + username to prevent brute force."""
ip = request.client.host if request.client else "unknown" ip = request.client.host if request.client else "unknown"
@@ -161,7 +165,7 @@ def login_key_extractor(request: Request) -> str:
key_extractor=login_key_extractor, key_extractor=login_key_extractor,
error_message="Too many login attempts. Please try again in 5 minutes.", error_message="Too many login attempts. Please try again in 5 minutes.",
) )
async def login(request: Request) -> dict[str, str]: async def login(_: Request) -> dict[str, str]:
"""Login endpoint with brute force protection.""" """Login endpoint with brute force protection."""
return {"message": "Login successful", "token": "..."} return {"message": "Login successful", "token": "..."}
@@ -179,7 +183,7 @@ def password_reset_key(request: Request) -> str:
key_extractor=password_reset_key, key_extractor=password_reset_key,
error_message="Too many password reset requests. Please try again later.", error_message="Too many password reset requests. Please try again later.",
) )
async def password_reset(request: Request) -> dict[str, str]: async def password_reset(_: Request) -> dict[str, str]:
"""Password reset with strict rate limiting.""" """Password reset with strict rate limiting."""
return {"message": "Password reset email sent"} return {"message": "Password reset email sent"}
@@ -197,23 +201,24 @@ webhook_rate_limit = RateLimitDependency(
async def check_webhook_limit( async def check_webhook_limit(
request: Request, _: Request,
webhook_url: str, webhook_url: str,
) -> None: ) -> None:
"""Check rate limit before sending webhook.""" """Check rate limit before sending webhook."""
# Create key based on destination domain # Create key based on destination domain
from urllib.parse import urlparse from urllib.parse import urlparse
domain = urlparse(webhook_url).netloc domain = urlparse(webhook_url).netloc
_key = f"webhook:{domain}" # Would be used with limiter in production _key = f"webhook:{domain}" # Would be used with limiter in production
# Manually check limit (simplified example) # Manually check limit (simplified example)
# In production, you'd use the limiter directly # In production, you'd use the limiter directly
_ = _key # Suppress unused variable warning __ = _key # Suppress unused variable warning
@app.post("/api/send-webhook") @app.post("/api/send-webhook")
async def send_webhook( async def send_webhook(
request: Request, _: Request,
webhook_url: str = "https://example.com/webhook", webhook_url: str = "https://example.com/webhook",
rate_info: Any = Depends(webhook_rate_limit), rate_info: Any = Depends(webhook_rate_limit),
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -231,6 +236,7 @@ async def send_webhook(
# Detect and limit similar requests (e.g., spam prevention) # Detect and limit similar requests (e.g., spam prevention)
# ============================================================================= # =============================================================================
def request_fingerprint(request: Request) -> str: def request_fingerprint(request: Request) -> str:
"""Create fingerprint based on request characteristics.""" """Create fingerprint based on request characteristics."""
ip = request.client.host if request.client else "unknown" ip = request.client.host if request.client else "unknown"
@@ -251,7 +257,7 @@ def request_fingerprint(request: Request) -> str:
key_extractor=request_fingerprint, key_extractor=request_fingerprint,
error_message="Too many submissions from this device.", error_message="Too many submissions from this device.",
) )
async def submit_form(request: Request) -> dict[str, str]: async def submit_form(_: Request) -> dict[str, str]:
"""Form submission with fingerprint-based rate limiting.""" """Form submission with fingerprint-based rate limiting."""
return {"message": "Form submitted successfully"} return {"message": "Form submitted successfully"}
@@ -261,13 +267,14 @@ async def submit_form(request: Request) -> dict[str, str]:
# Different limits during peak vs off-peak hours # Different limits during peak vs off-peak hours
# ============================================================================= # =============================================================================
def is_peak_hours() -> bool: def is_peak_hours() -> bool:
"""Check if current time is during peak hours (9 AM - 6 PM UTC).""" """Check if current time is during peak hours (9 AM - 6 PM UTC)."""
current_hour = time.gmtime().tm_hour current_hour = time.gmtime().tm_hour
return 9 <= current_hour < 18 return 9 <= current_hour < 18
def peak_aware_exempt(request: Request) -> bool: def peak_aware_exempt(_: Request) -> bool:
"""Exempt requests during off-peak hours.""" """Exempt requests during off-peak hours."""
return not is_peak_hours() return not is_peak_hours()
@@ -278,7 +285,7 @@ def peak_aware_exempt(request: Request) -> bool:
window_size=60, window_size=60,
exempt_when=peak_aware_exempt, # No limit during off-peak exempt_when=peak_aware_exempt, # No limit during off-peak
) )
async def peak_aware_endpoint(request: Request) -> dict[str, Any]: async def peak_aware_endpoint(_: Request) -> dict[str, Any]:
"""Stricter limits during peak hours.""" """Stricter limits during peak hours."""
return { return {
"message": "Success", "message": "Success",
@@ -297,7 +304,7 @@ per_hour = RateLimitDependency(limit=1000, window_size=3600, key_prefix="hour")
async def cascading_limits( async def cascading_limits(
request: Request, _: Request,
sec_info: Any = Depends(per_second), sec_info: Any = Depends(per_second),
min_info: Any = Depends(per_minute), min_info: Any = Depends(per_minute),
hour_info: Any = Depends(per_hour), hour_info: Any = Depends(per_hour),
@@ -312,7 +319,7 @@ async def cascading_limits(
@app.get("/api/cascading") @app.get("/api/cascading")
async def cascading_endpoint( async def cascading_endpoint(
request: Request, _: Request,
limits: dict[str, Any] = Depends(cascading_limits), limits: dict[str, Any] = Depends(cascading_limits),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Endpoint with per-second, per-minute, and per-hour limits.""" """Endpoint with per-second, per-minute, and per-hour limits."""

View File

@@ -346,7 +346,7 @@ def create_app_with_config() -> FastAPI:
) )
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse: async def _rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
@@ -358,17 +358,17 @@ def create_app_with_config() -> FastAPI:
@app.get("/") @app.get("/")
@rate_limit(limit=10, window_size=60) @rate_limit(limit=10, window_size=60)
async def root(_: Request) -> dict[str, str]: async def _root(_: Request) -> dict[str, str]:
return {"message": "Hello from config-loaded app!"} return {"message": "Hello from config-loaded app!"}
@app.get("/health") @app.get("/health")
async def health() -> dict[str, str]: async def _health() -> dict[str, str]:
"""Health check - exempt from rate limiting.""" """Health check - exempt from rate limiting."""
return {"status": "healthy"} return {"status": "healthy"}
@app.get("/api/data") @app.get("/api/data")
@rate_limit(limit=50, window_size=60) @rate_limit(limit=50, window_size=60)
async def get_data(_: Request) -> dict[str, str]: async def _get_data(_: Request) -> dict[str, str]:
return {"data": "Some API data"} return {"data": "Some API data"}
return app return app

View File

@@ -3,21 +3,24 @@
from __future__ import annotations from __future__ import annotations
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator from typing import TYPE_CHECKING
from fastapi import Depends, FastAPI, Request from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi_traffic import ( from fastapi_traffic import (
Algorithm, Algorithm,
RateLimitExceeded,
RateLimiter, RateLimiter,
RateLimitExceeded,
SQLiteBackend, SQLiteBackend,
rate_limit, rate_limit,
) )
from fastapi_traffic.core.decorator import RateLimitDependency from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.core.limiter import set_limiter
if TYPE_CHECKING:
from collections.abc import AsyncIterator
# Configure global rate limiter with SQLite backend for persistence # Configure global rate limiter with SQLite backend for persistence
backend = SQLiteBackend("rate_limits.db") backend = SQLiteBackend("rate_limits.db")
limiter = RateLimiter(backend) limiter = RateLimiter(backend)
@@ -25,7 +28,7 @@ set_limiter(limiter)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]: async def lifespan(_: FastAPI) -> AsyncIterator[None]:
"""Manage application lifespan - startup and shutdown.""" """Manage application lifespan - startup and shutdown."""
# Startup: Initialize the rate limiter # Startup: Initialize the rate limiter
await limiter.initialize() await limiter.initialize()
@@ -39,7 +42,7 @@ app = FastAPI(title="FastAPI Traffic Example", lifespan=lifespan)
# Exception handler for rate limit exceeded # Exception handler for rate limit exceeded
@app.exception_handler(RateLimitExceeded) @app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
"""Handle rate limit exceeded exceptions.""" """Handle rate limit exceeded exceptions."""
headers = exc.limit_info.to_headers() if exc.limit_info else {} headers = exc.limit_info.to_headers() if exc.limit_info else {}
return JSONResponse( return JSONResponse(
@@ -56,7 +59,7 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONRe
# Example 1: Basic decorator usage # Example 1: Basic decorator usage
@app.get("/api/basic") @app.get("/api/basic")
@rate_limit(100, 60) # 100 requests per minute @rate_limit(100, 60) # 100 requests per minute
async def basic_endpoint(request: Request) -> dict[str, str]: async def basic_endpoint(_: Request) -> dict[str, str]:
"""Basic rate-limited endpoint.""" """Basic rate-limited endpoint."""
return {"message": "Hello, World!"} return {"message": "Hello, World!"}
@@ -69,7 +72,7 @@ async def basic_endpoint(request: Request) -> dict[str, str]:
algorithm=Algorithm.TOKEN_BUCKET, algorithm=Algorithm.TOKEN_BUCKET,
burst_size=10, # Allow bursts of up to 10 requests burst_size=10, # Allow bursts of up to 10 requests
) )
async def token_bucket_endpoint(request: Request) -> dict[str, str]: async def token_bucket_endpoint(_: Request) -> dict[str, str]:
"""Endpoint using token bucket algorithm.""" """Endpoint using token bucket algorithm."""
return {"message": "Token bucket rate limiting"} return {"message": "Token bucket rate limiting"}
@@ -81,7 +84,7 @@ async def token_bucket_endpoint(request: Request) -> dict[str, str]:
window_size=60, window_size=60,
algorithm=Algorithm.SLIDING_WINDOW, algorithm=Algorithm.SLIDING_WINDOW,
) )
async def sliding_window_endpoint(request: Request) -> dict[str, str]: async def sliding_window_endpoint(_: Request) -> dict[str, str]:
"""Endpoint using sliding window algorithm.""" """Endpoint using sliding window algorithm."""
return {"message": "Sliding window rate limiting"} return {"message": "Sliding window rate limiting"}
@@ -99,7 +102,7 @@ def api_key_extractor(request: Request) -> str:
window_size=3600, # 1000 requests per hour window_size=3600, # 1000 requests per hour
key_extractor=api_key_extractor, key_extractor=api_key_extractor,
) )
async def api_key_endpoint(request: Request) -> dict[str, str]: async def api_key_endpoint(_: Request) -> dict[str, str]:
"""Endpoint rate limited by API key.""" """Endpoint rate limited by API key."""
return {"message": "Rate limited by API key"} return {"message": "Rate limited by API key"}
@@ -110,7 +113,7 @@ rate_limit_dep = RateLimitDependency(limit=20, window_size=60)
@app.get("/api/dependency") @app.get("/api/dependency")
async def dependency_endpoint( async def dependency_endpoint(
request: Request, _: Request,
rate_info: dict[str, object] = Depends(rate_limit_dep), rate_info: dict[str, object] = Depends(rate_limit_dep),
) -> dict[str, object]: ) -> dict[str, object]:
"""Endpoint using rate limit as dependency.""" """Endpoint using rate limit as dependency."""
@@ -132,7 +135,7 @@ def is_admin(request: Request) -> bool:
window_size=60, window_size=60,
exempt_when=is_admin, exempt_when=is_admin,
) )
async def admin_exempt_endpoint(request: Request) -> dict[str, str]: async def admin_exempt_endpoint(_: Request) -> dict[str, str]:
"""Endpoint with admin exemption.""" """Endpoint with admin exemption."""
return {"message": "Admins are exempt from rate limiting"} return {"message": "Admins are exempt from rate limiting"}
@@ -144,7 +147,7 @@ async def admin_exempt_endpoint(request: Request) -> dict[str, str]:
window_size=60, window_size=60,
cost=10, # This endpoint costs 10 tokens per request cost=10, # This endpoint costs 10 tokens per request
) )
async def expensive_endpoint(request: Request) -> dict[str, str]: async def expensive_endpoint(_: Request) -> dict[str, str]:
"""Expensive operation that costs more tokens.""" """Expensive operation that costs more tokens."""
return {"message": "Expensive operation completed"} return {"message": "Expensive operation completed"}