Key Extractors
==============

A key extractor is a function that identifies who's making a request. By default,
FastAPI Traffic uses the client's IP address, but you can customize this to fit
your authentication model.

How It Works
------------

Every rate limit needs a way to group requests. The key extractor receives the
incoming request and returns a string that identifies the client:

.. code-block:: python

   from fastapi import Request


   def my_key_extractor(request: Request) -> str:
       return "some-unique-identifier"

All requests that return the same identifier share the same rate limit bucket.

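To make the bucket idea concrete, here is a minimal in-memory sketch of a fixed-window counter keyed by the extractor's string. It only illustrates the grouping; it is not how FastAPI Traffic's backends store counters, and the ``FixedWindowCounter`` class is hypothetical:

.. code-block:: python

   from __future__ import annotations

   import time


   class FixedWindowCounter:
       """Hypothetical in-memory counter: one bucket per key string."""

       def __init__(self, limit: int, window_size: float) -> None:
           self.limit = limit
           self.window_size = window_size
           # key -> (window_start, requests_seen_in_window)
           self._buckets: dict[str, tuple[float, int]] = {}

       def hit(self, key: str, now: float | None = None) -> bool:
           """Record one request for ``key``; return True if it is allowed."""
           now = time.monotonic() if now is None else now
           start, count = self._buckets.get(key, (now, 0))
           if now - start >= self.window_size:
               # The previous window expired, so this key starts a fresh one.
               start, count = now, 0
           if count >= self.limit:
               return False
           self._buckets[key] = (start, count + 1)
           return True


   counter = FixedWindowCounter(limit=2, window_size=60)
   # Requests that resolve to the same key share one budget...
   print(counter.hit("ip:203.0.113.7", now=0))  # True
   print(counter.hit("ip:203.0.113.7", now=1))  # True
   print(counter.hit("ip:203.0.113.7", now=2))  # False
   # ...while a different key has its own bucket.
   print(counter.hit("ip:198.51.100.9", now=2))  # True

Whatever the storage, the extractor's only job is to decide which key a request counts against.
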
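Default Behavior
----------------

The default extractor looks for the client IP in this order:

1. ``X-Forwarded-For`` header (first IP in the list)
2. ``X-Real-IP`` header
3. Direct connection IP (``request.client.host``)
4. Falls back to ``"unknown"``

This handles most reverse proxy setups automatically. Clients can send these
headers themselves, though, so they are only trustworthy when a proxy you control
overwrites them (rather than appending to them). If your app is reachable
directly, a client can put a different value in ``X-Forwarded-For`` on each
request and land in a fresh bucket every time.
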
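The lookup order above can be written as a plain function, which is handy for reasoning about which key a given request will get. This is a sketch of the documented order, not FastAPI Traffic's source; ``default_client_ip`` is a hypothetical name, and it takes a headers mapping plus the connection host so it can be tried without a running server:

.. code-block:: python

   from __future__ import annotations

   from collections.abc import Mapping


   def default_client_ip(headers: Mapping[str, str], client_host: str | None) -> str:
       """Sketch of the documented lookup order (hypothetical helper).

       ``headers`` is looked up with lowercase names; Starlette's
       ``request.headers`` is case-insensitive, so it can be passed directly.
       """
       forwarded = headers.get("x-forwarded-for")
       if forwarded:
           # "client, proxy1, proxy2" -> take the first (left-most) address.
           first = forwarded.split(",")[0].strip()
           if first:
               return first
       real_ip = headers.get("x-real-ip")
       if real_ip:
           return real_ip.strip()
       if client_host:
           return client_host
       return "unknown"


   print(default_client_ip({"x-forwarded-for": "203.0.113.7, 10.0.0.1"}, "10.0.0.2"))
   print(default_client_ip({}, None))

Inside an extractor you could call it as ``default_client_ip(request.headers, request.client.host if request.client else None)``.
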
Rate Limiting by API Key
------------------------

For authenticated APIs, you probably want to limit by API key:

.. code-block:: python

   from fastapi import FastAPI, Request

   from fastapi_traffic import rate_limit

   app = FastAPI()


   def api_key_extractor(request: Request) -> str:
       """Rate limit by API key."""
       api_key = request.headers.get("X-API-Key")
       if api_key:
           return f"apikey:{api_key}"
       # Fall back to IP for unauthenticated requests
       return f"ip:{request.client.host}" if request.client else "ip:unknown"


   @app.get("/api/data")
   @rate_limit(1000, 3600, key_extractor=api_key_extractor)
   async def get_data(request: Request):
       return {"data": "here"}

Now each API key gets its own rate limit bucket (here, 1000 requests per
3600-second window).

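The extractor above puts the raw API key into the rate limit key. If you would rather not have secrets appear there (a storage backend such as Redis may persist these keys, and they can surface in logs or monitoring), you can key on a digest instead. A minimal standard-library sketch; ``hashed_api_key_bucket`` is a hypothetical helper, and truncating the SHA-256 digest to 16 hex characters is an arbitrary choice that keeps keys short while making collisions very unlikely:

.. code-block:: python

   import hashlib


   def hashed_api_key_bucket(api_key: str) -> str:
       """Return a rate limit key derived from, but not containing, the API key."""
       digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
       # 16 hex characters = 64 bits of the digest.
       return f"apikey:{digest[:16]}"


   print(hashed_api_key_bucket("sk_live_example"))

In ``api_key_extractor`` you would return ``hashed_api_key_bucket(api_key)`` in place of ``f"apikey:{api_key}"``. The same API key always yields the same bucket, so limits behave as before.
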
Rate Limiting by User
---------------------

If you're using authentication middleware that sets the user:

.. code-block:: python

   def user_extractor(request: Request) -> str:
       """Rate limit by authenticated user."""
       # Assuming your auth middleware sets request.state.user
       user = getattr(request.state, "user", None)
       if user:
           return f"user:{user.id}"
       return f"ip:{request.client.host}" if request.client else "ip:unknown"


   @app.get("/api/profile")
   @rate_limit(100, 60, key_extractor=user_extractor)
   async def get_profile(request: Request):
       return {"profile": "data"}

Rate Limiting by Tenant
-----------------------

For multi-tenant applications:

.. code-block:: python

   def tenant_extractor(request: Request) -> str:
       """Rate limit by tenant."""
       # From subdomain
       host = request.headers.get("host", "")
       if "." in host:
           tenant = host.split(".")[0]
           return f"tenant:{tenant}"

       # Or from header
       tenant = request.headers.get("X-Tenant-ID")
       if tenant:
           return f"tenant:{tenant}"

       return "tenant:default"

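Splitting on the first dot is a quick heuristic, but it also turns a bare ``example.com`` into tenant ``example`` and an IP host such as ``127.0.0.1:8000`` into tenant ``127``. A slightly stricter sketch, assuming your tenants live under one known base domain (``BASE_DOMAIN`` is a placeholder and ``tenant_from_host`` a hypothetical helper) and that the ``Host`` header may carry a port:

.. code-block:: python

   from __future__ import annotations

   # Placeholder: replace with the domain your tenants are served under.
   BASE_DOMAIN = "example.com"


   def tenant_from_host(host: str, base_domain: str = BASE_DOMAIN) -> str | None:
       """Return the tenant label from ``<tenant>.<base_domain>``, else None."""
       # Drop any port ("acme.example.com:8000" -> "acme.example.com") and
       # lowercase, since DNS names are case-insensitive.
       hostname = host.rsplit(":", 1)[0].strip().lower()
       suffix = "." + base_domain
       if not hostname.endswith(suffix):
           return None
       label = hostname[: -len(suffix)]
       # Accept exactly one label, so "a.b.example.com" is not a tenant.
       if not label or "." in label:
           return None
       return label


   print(tenant_from_host("acme.example.com:8000"))  # acme
   print(tenant_from_host("127.0.0.1:8000"))  # None

``tenant_extractor`` could call ``tenant_from_host(request.headers.get("host", ""))`` and fall through to the ``X-Tenant-ID`` header when it returns ``None``.
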
Combining Identifiers
---------------------

Sometimes you want to combine multiple factors:

.. code-block:: python

   def combined_extractor(request: Request) -> str:
       """Rate limit by user AND endpoint."""
       user = getattr(request.state, "user", None)
       user_id = user.id if user else "anonymous"
       endpoint = request.url.path
       return f"{user_id}:{endpoint}"

This gives each user a separate limit for each endpoint. Note that all
unauthenticated clients share a single ``anonymous`` bucket per endpoint; fall
back to the IP address instead if they should be limited individually.

Tiered Rate Limits
------------------

Different users might have different limits. Handle this with a custom extractor
that includes the tier:

.. code-block:: python

   def tiered_extractor(request: Request) -> str:
       """Include tier in the key for different limits."""
       user = getattr(request.state, "user", None)
       if user:
           # Premium users get a different bucket
           tier = "premium" if user.is_premium else "free"
           return f"{tier}:{user.id}"
       ip = request.client.host if request.client else "unknown"
       return f"anonymous:{ip}"

Then apply different limits based on tier:

.. code-block:: python

   # You'd typically do this with middleware or dependency injection
   # to check the tier and apply the appropriate limit.
   # get_limiter and RateLimitConfig are assumed to come from fastapi_traffic;
   # check your installed version for their exact import paths.

   @app.get("/api/data")
   async def get_data(request: Request):
       user = getattr(request.state, "user", None)
       if user and user.is_premium:
           # Premium: 10000 req/hour
           limit, window = 10000, 3600
       else:
           # Free: 100 req/hour
           limit, window = 100, 3600

       # Apply rate limit manually
       limiter = get_limiter()
       config = RateLimitConfig(limit=limit, window_size=window)
       await limiter.hit(request, config)

       return {"data": "here"}

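With more than two tiers, the ``if``/``else`` in the endpoint grows quickly. One alternative is a lookup table keyed on the tier prefix that ``tiered_extractor`` already puts in the key. A sketch: ``TIER_LIMITS`` and ``limits_for`` are hypothetical names, the numbers mirror the example above, and anything that is not a known tier (including ``anonymous``) falls back to the free limits:

.. code-block:: python

   from __future__ import annotations

   # Hypothetical tier table: tier name -> (limit, window_size in seconds).
   TIER_LIMITS: dict[str, tuple[int, int]] = {
       "premium": (10_000, 3600),
       "free": (100, 3600),
   }
   DEFAULT_TIER = "free"


   def limits_for(key: str) -> tuple[int, int]:
       """Return (limit, window) for a key shaped like tiered_extractor's output."""
       tier = key.split(":", 1)[0]
       # Unknown tiers (e.g. "anonymous") get the default tier's limits.
       return TIER_LIMITS.get(tier, TIER_LIMITS[DEFAULT_TIER])


   print(limits_for("premium:42"))  # (10000, 3600)
   print(limits_for("anonymous:203.0.113.7"))  # (100, 3600)

In the endpoint, ``limit, window = limits_for(tiered_extractor(request))`` would replace the branching.
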
Geographic Rate Limiting
------------------------

Limit by country or region:

.. code-block:: python

   def geo_extractor(request: Request) -> str:
       """Rate limit by country."""
       # Assuming you have a GeoIP lookup. CF-IPCountry is set by Cloudflare;
       # only trust it when traffic actually passes through Cloudflare.
       country = request.headers.get("CF-IPCountry", "XX")
       ip = request.client.host if request.client else "unknown"
       return f"{country}:{ip}"

This lets you apply different limits to different regions if needed.

Endpoint-Specific Keys
----------------------

Give the same user a separate bucket for each endpoint and HTTP method:

.. code-block:: python

   def endpoint_user_extractor(request: Request) -> str:
       """Separate limits per endpoint per user."""
       user = getattr(request.state, "user", None)
       if user:
           user_id = user.id
       else:
           user_id = request.client.host if request.client else "unknown"
       method = request.method
       path = request.url.path
       return f"{user_id}:{method}:{path}"

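Because ``request.url.path`` is the concrete path, an endpoint such as ``/users/{user_id}`` produces a different bucket for every ID (``/users/1``, ``/users/2``, ...), which may or may not be what you want. If you want one bucket per route instead, one option is to collapse ID-like segments before building the key. A hypothetical sketch (``normalize_path`` is not part of FastAPI Traffic) that treats all-digit segments and UUIDs as IDs:

.. code-block:: python

   import re
   import uuid

   _NUMERIC = re.compile(r"\d+")


   def _looks_like_id(segment: str) -> bool:
       """Heuristic: all-digit segments and UUIDs are treated as IDs."""
       if _NUMERIC.fullmatch(segment):
           return True
       try:
           uuid.UUID(segment)
       except ValueError:
           return False
       return True


   def normalize_path(path: str) -> str:
       """Replace ID-like path segments with an ``{id}`` placeholder."""
       return "/".join(
           "{id}" if _looks_like_id(segment) else segment
           for segment in path.split("/")
       )


   print(normalize_path("/users/123/orders/550e8400-e29b-41d4-a716-446655440000"))
   # /users/{id}/orders/{id}

Using ``normalize_path(request.url.path)`` in ``endpoint_user_extractor`` then yields one bucket per route and method. The heuristic is deliberately simple: a numeric segment that is not an ID (a year, say) is collapsed too, so adapt it to your URL scheme.
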
Best Practices
--------------

1. **Always have a fallback.** If your primary identifier isn't available, fall
   back to IP:

   .. code-block:: python

      def safe_extractor(request: Request) -> str:
          api_key = request.headers.get("X-API-Key")
          if api_key:
              return f"key:{api_key}"
          return f"ip:{request.client.host if request.client else 'unknown'}"

2. **Use prefixes.** When mixing identifier types, prefix them to avoid collisions.
   Without prefixes, a user ID and an API key that happen to be the same string
   would share one bucket:

   .. code-block:: python

      # Good - clear what each key represents
      return f"user:{user_id}"
      return f"ip:{ip_address}"
      return f"key:{api_key}"

      # Bad - could collide
      return user_id
      return ip_address

3. **Keep it fast.** The extractor runs on every request. Avoid database lookups
   or other expensive operations:

   .. code-block:: python

      # Bad - database lookup on every request (db is a placeholder)
      def slow_extractor(request: Request) -> str:
          user = db.get_user(request.headers.get("Authorization"))
          return user.id

      # Good - use data already in the request
      def fast_extractor(request: Request) -> str:
          user = getattr(request.state, "user", None)  # Set by auth middleware
          if user:
              return f"user:{user.id}"
          return f"ip:{request.client.host if request.client else 'unknown'}"

4. **Be consistent.** The same client should always get the same key. Watch out
   for things like:

   - IP addresses changing (mobile users)
   - Case sensitivity (normalize to lowercase)
   - Whitespace (strip it)

   .. code-block:: python

      def normalized_extractor(request: Request) -> str:
          api_key = request.headers.get("X-API-Key", "").strip().lower()
          if api_key:
              return f"key:{api_key}"
          return f"ip:{request.client.host if request.client else 'unknown'}"

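Changing IP addresses are the hardest of these to handle. One common mitigation for IPv6 clients, whose devices often rotate through temporary addresses within a single /64 network, is to key on the /64 prefix rather than the full address; parsing also gives a canonical text form, so ``2001:DB8::1`` and ``2001:db8:0:0::1`` map to the same key. A sketch using the standard ``ipaddress`` module; ``normalize_ip`` is a hypothetical helper and the /64 prefix length is an assumption you may want to tune:

.. code-block:: python

   import ipaddress


   def normalize_ip(raw: str, ipv6_prefix: int = 64) -> str:
       """Canonicalize an IP string, grouping IPv6 addresses by network prefix.

       Anything that does not parse as an IP address (such as "unknown") is
       returned stripped but otherwise unchanged.
       """
       candidate = raw.strip()
       try:
           ip = ipaddress.ip_address(candidate)
       except ValueError:
           return candidate
       if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
           # "::ffff:203.0.113.7" is really an IPv4 client; treat it as one.
           ip = ip.ipv4_mapped
       if isinstance(ip, ipaddress.IPv6Address):
           # Key on the enclosing network instead of the full address.
           network = ipaddress.ip_network(f"{ip}/{ipv6_prefix}", strict=False)
           return str(network)
       return str(ip)


   print(normalize_ip("2001:DB8::1"))  # 2001:db8::/64
   print(normalize_ip("::ffff:203.0.113.7"))  # 203.0.113.7

The trade-off is that all clients sharing a /64 (for example, devices on one home network) also share a bucket.
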
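Using with Middleware
---------------------

Key extractors work the same way with middleware; the difference is that the
limit is applied to requests passing through the middleware rather than being
attached to a single endpoint:

.. code-block:: python

   from fastapi import FastAPI

   from fastapi_traffic.middleware import RateLimitMiddleware

   app = FastAPI()

   # api_key_extractor is the function from "Rate Limiting by API Key"
   app.add_middleware(
       RateLimitMiddleware,
       limit=1000,
       window_size=60,
       key_extractor=api_key_extractor,
   )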