Initial commit: fastapi-traffic rate limiting library
- Core rate limiting with multiple algorithms (sliding window, token bucket, etc.)
- SQLite and memory backends
- Decorator and dependency-injection patterns
- Middleware support
- Example usage files
This commit is contained in:
466
fastapi_traffic/core/algorithms.py
Normal file
466
fastapi_traffic/core/algorithms.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Rate limiting algorithms implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi_traffic.core.models import RateLimitInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi_traffic.backends.base import Backend
|
||||
|
||||
|
||||
class Algorithm(str, Enum):
    """Identifiers for the supported rate limiting strategies.

    Mixes in ``str`` so values compare equal to their wire names and can
    be used directly in configuration.
    """

    # Continuous refill; permits bursts up to bucket capacity.
    TOKEN_BUCKET = "token_bucket"
    # Exact per-request timestamp log over a rolling window.
    SLIDING_WINDOW = "sliding_window"
    # Single counter per aligned window; cheapest option.
    FIXED_WINDOW = "fixed_window"
    # Constant drain rate; smooths out traffic spikes.
    LEAKY_BUCKET = "leaky_bucket"
    # Weighted two-window counter; precision/memory compromise.
    SLIDING_WINDOW_COUNTER = "sliding_window_counter"
class BaseAlgorithm(ABC):
    """Base class for rate limiting algorithms.

    Concrete subclasses implement :meth:`check`, :meth:`reset` and
    :meth:`get_state` on top of a pluggable storage backend.

    Args:
        limit: Maximum number of allowed requests per window; must be > 0.
        window_size: Window length in seconds; must be > 0.
        backend: Storage backend used to persist per-key state.
        burst_size: Optional burst capacity; a falsy value (``None`` or 0)
            falls back to ``limit``.

    Raises:
        ValueError: If ``limit`` or ``window_size`` is not positive, or
            ``burst_size`` is negative.
    """

    __slots__ = ("limit", "window_size", "backend", "burst_size")

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        # Validate eagerly so misconfiguration fails at construction time
        # instead of surfacing later as ZeroDivisionError (subclasses divide
        # by window_size) or a limiter that never admits a request.
        if limit <= 0:
            msg = f"limit must be positive, got {limit}"
            raise ValueError(msg)
        if window_size <= 0:
            msg = f"window_size must be positive, got {window_size}"
            raise ValueError(msg)
        if burst_size is not None and burst_size < 0:
            msg = f"burst_size must not be negative, got {burst_size}"
            raise ValueError(msg)
        self.limit = limit
        self.window_size = window_size
        self.backend = backend
        # Falsy burst_size (None or 0) defaults to the steady-state limit.
        self.burst_size = burst_size or limit

    @abstractmethod
    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Check if a request for *key* is allowed and update state."""
        ...

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Reset the rate limit state for *key*."""
        ...

    @abstractmethod
    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Get current state for *key* without consuming a token."""
        ...
class TokenBucketAlgorithm(BaseAlgorithm):
    """Token bucket algorithm - allows bursts up to bucket capacity.

    Tokens refill continuously at ``limit / window_size`` per second and
    each allowed request consumes one token.  A missing backend entry is
    treated as a full bucket, so the first request takes the same code
    path as every other one (the previous special-case branch reported an
    inconsistent ``reset_at``: ``now + window_size`` instead of the
    time-until-full formula used everywhere else).
    """

    __slots__ = ("refill_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        # Continuous refill rate in tokens per second.
        self.refill_rate = limit / window_size

    def _refilled_tokens(self, state: dict | None, now: float) -> float:
        """Return the token count for *state* after refilling up to *now*.

        ``None`` (no stored state) means the bucket was never used and is
        therefore full.  The count is capped at ``burst_size``.
        """
        if state is None:
            return float(self.burst_size)
        tokens = float(state.get("tokens", self.burst_size))
        last_update = float(state.get("last_update", now))
        elapsed = now - last_update
        return min(self.burst_size, tokens + elapsed * self.refill_rate)

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Consume one token for *key* if one is available.

        Returns ``(allowed, info)``; ``info.retry_after`` is set only
        when the request was denied.
        """
        now = time.time()
        state = await self.backend.get(key)
        tokens = self._refilled_tokens(state, now)

        if tokens >= 1:
            tokens -= 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            # Seconds until one whole token has refilled.
            retry_after = (1 - tokens) / self.refill_rate

        await self.backend.set(
            key,
            {"tokens": tokens, "last_update": now},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=int(tokens),
            # Moment at which the bucket is completely full again.
            reset_at=now + (self.burst_size - tokens) / self.refill_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Forget all token state for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Return current state for *key* without consuming a token.

        Returns ``None`` when no state is stored for *key*.
        """
        now = time.time()
        state = await self.backend.get(key)
        if state is None:
            return None

        tokens = self._refilled_tokens(state, now)
        return RateLimitInfo(
            limit=self.limit,
            remaining=int(tokens),
            reset_at=now + (self.burst_size - tokens) / self.refill_rate,
            window_size=self.window_size,
        )
class SlidingWindowAlgorithm(BaseAlgorithm):
    """Sliding window log algorithm - precise but memory intensive.

    Keeps the timestamp of every request inside the rolling window, so
    the count is exact at the cost of O(limit) storage per key.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Record a request for *key* if the window still has room."""
        now = time.time()
        cutoff = now - self.window_size
        stored = await self.backend.get(key)

        # Keep only timestamps that are still inside the rolling window.
        log: list[float] = []
        if stored is not None:
            for raw in stored.get("timestamps", []):
                ts = float(raw)
                if ts > cutoff:
                    log.append(ts)

        allowed = len(log) < self.limit
        if allowed:
            log.append(now)
            retry_after = None
        else:
            # The limit frees up when the oldest entry leaves the window.
            earliest = min(log) if log else now
            retry_after = earliest + self.window_size - now

        await self.backend.set(
            key,
            {"timestamps": log},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - len(log)),
            reset_at=(min(log) if log else now) + self.window_size,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Discard the timestamp log for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Inspect the window for *key* without recording a request."""
        now = time.time()
        cutoff = now - self.window_size
        stored = await self.backend.get(key)
        if stored is None:
            return None

        log: list[float] = []
        for raw in stored.get("timestamps", []):
            ts = float(raw)
            if ts > cutoff:
                log.append(ts)

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - len(log)),
            reset_at=(min(log) if log else now) + self.window_size,
            window_size=self.window_size,
        )
class FixedWindowAlgorithm(BaseAlgorithm):
    """Fixed window algorithm - simple and efficient.

    Requests are counted per aligned window of ``window_size`` seconds;
    the counter starts over whenever a new window begins.
    """

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Register a hit for *key* in the current fixed window."""
        now = time.time()
        start = (now // self.window_size) * self.window_size
        end = start + self.window_size
        stored = await self.backend.get(key)

        # A stale entry (from an earlier window) counts as zero hits.
        hits = 0
        if stored is not None and float(stored.get("window_start", 0)) == start:
            hits = int(stored.get("count", 0))

        if hits >= self.limit:
            allowed = False
            retry_after = end - now
        else:
            hits += 1
            allowed = True
            retry_after = None

        await self.backend.set(
            key,
            {"count": hits, "window_start": start},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - hits),
            reset_at=end,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Drop the stored counter for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Peek at the counter for *key* without registering a hit."""
        now = time.time()
        start = (now // self.window_size) * self.window_size
        stored = await self.backend.get(key)
        if stored is None:
            return None

        hits = 0
        if float(stored.get("window_start", 0)) == start:
            hits = int(stored.get("count", 0))

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, self.limit - hits),
            reset_at=start + self.window_size,
            window_size=self.window_size,
        )
class LeakyBucketAlgorithm(BaseAlgorithm):
    """Leaky bucket algorithm - smooths out bursts.

    Each request adds one unit of "water"; the bucket drains at a fixed
    rate.  A request is rejected when the bucket is already full.
    """

    __slots__ = ("leak_rate",)

    def __init__(
        self,
        limit: int,
        window_size: float,
        backend: Backend,
        *,
        burst_size: int | None = None,
    ) -> None:
        super().__init__(limit, window_size, backend, burst_size=burst_size)
        # Drain rate in requests per second.
        self.leak_rate = limit / window_size

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Try to add one request to the bucket for *key*."""
        now = time.time()
        stored = await self.backend.get(key)

        level = 0.0
        if stored is not None:
            level = float(stored.get("water_level", 0))
            since = now - float(stored.get("last_update", now))
            # Drain whatever leaked out since the last update.
            level = max(0, level - since * self.leak_rate)

        if level < self.burst_size:
            level += 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            # Seconds until enough has drained for one more request.
            retry_after = (level - self.burst_size + 1) / self.leak_rate

        await self.backend.set(
            key,
            {"water_level": level, "last_update": now},
            ttl=self.window_size * 2,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            # Moment at which the bucket would be fully drained.
            reset_at=now + level / self.leak_rate,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Empty the bucket for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Inspect the bucket for *key* without adding a request."""
        now = time.time()
        stored = await self.backend.get(key)
        if stored is None:
            return None

        level = float(stored.get("water_level", 0))
        since = now - float(stored.get("last_update", now))
        level = max(0, level - since * self.leak_rate)

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.burst_size - level)),
            reset_at=now + level / self.leak_rate,
            window_size=self.window_size,
        )
class SlidingWindowCounterAlgorithm(BaseAlgorithm):
    """Sliding window counter - balance between precision and memory.

    Keeps one counter for the current aligned window and one for the
    previous window; the effective count weights the previous counter by
    how much of it still overlaps the rolling window.
    """

    def _window_counts(
        self,
        state: dict | None,
        previous_window: float,
    ) -> tuple[int, int]:
        """Return ``(prev_count, curr_count)`` normalized to the present.

        A stored window older than *previous_window* is fully expired; a
        stored window equal to it means its "current" counter is now the
        previous one.
        """
        if state is None:
            return 0, 0
        prev = int(state.get("prev_count", 0))
        curr = int(state.get("curr_count", 0))
        stored = float(state.get("current_window", 0))
        if stored < previous_window:
            return 0, 0
        if stored == previous_window:
            return curr, 0
        return prev, curr

    async def check(self, key: str) -> tuple[bool, RateLimitInfo]:
        """Register a request for *key* if the weighted count permits."""
        now = time.time()
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        window_progress = (now - current_window) / self.window_size

        state = await self.backend.get(key)
        prev_count, curr_count = self._window_counts(state, previous_window)

        # Previous window contributes proportionally to its remaining overlap.
        weighted_count = prev_count * (1 - window_progress) + curr_count

        if weighted_count < self.limit:
            curr_count += 1
            allowed = True
            retry_after = None
        else:
            allowed = False
            retry_after = self.window_size * (1 - window_progress)

        await self.backend.set(
            key,
            {
                "prev_count": prev_count,
                "curr_count": curr_count,
                "current_window": current_window,
            },
            ttl=self.window_size * 3,
        )

        return allowed, RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.limit - weighted_count)),
            reset_at=current_window + self.window_size,
            retry_after=retry_after,
            window_size=self.window_size,
        )

    async def reset(self, key: str) -> None:
        """Drop both window counters for *key*."""
        await self.backend.delete(key)

    async def get_state(self, key: str) -> RateLimitInfo | None:
        """Inspect the weighted count for *key* without registering."""
        now = time.time()
        current_window = (now // self.window_size) * self.window_size
        previous_window = current_window - self.window_size
        window_progress = (now - current_window) / self.window_size

        state = await self.backend.get(key)
        if state is None:
            return None

        prev_count, curr_count = self._window_counts(state, previous_window)
        weighted_count = prev_count * (1 - window_progress) + curr_count

        return RateLimitInfo(
            limit=self.limit,
            remaining=max(0, int(self.limit - weighted_count)),
            reset_at=current_window + self.window_size,
            window_size=self.window_size,
        )
def get_algorithm(
    algorithm: Algorithm,
    limit: int,
    window_size: float,
    backend: Backend,
    *,
    burst_size: int | None = None,
) -> BaseAlgorithm:
    """Factory function to create algorithm instances.

    Args:
        algorithm: Which strategy to instantiate.
        limit: Maximum requests per window.
        window_size: Window length in seconds.
        backend: Storage backend for per-key state.
        burst_size: Optional burst capacity override.

    Raises:
        ValueError: If *algorithm* is not a known member.
    """
    registry: dict[Algorithm, type[BaseAlgorithm]] = {
        Algorithm.TOKEN_BUCKET: TokenBucketAlgorithm,
        Algorithm.SLIDING_WINDOW: SlidingWindowAlgorithm,
        Algorithm.FIXED_WINDOW: FixedWindowAlgorithm,
        Algorithm.LEAKY_BUCKET: LeakyBucketAlgorithm,
        Algorithm.SLIDING_WINDOW_COUNTER: SlidingWindowCounterAlgorithm,
    }

    try:
        algorithm_class = registry[algorithm]
    except KeyError:
        msg = f"Unknown algorithm: {algorithm}"
        raise ValueError(msg) from None

    return algorithm_class(limit, window_size, backend, burst_size=burst_size)
Reference in New Issue
Block a user