"""Rate limiting algorithms implementation.""" from __future__ import annotations import time from abc import ABC, abstractmethod from enum import Enum from typing import TYPE_CHECKING from fastapi_traffic.core.models import RateLimitInfo if TYPE_CHECKING: from fastapi_traffic.backends.base import Backend class Algorithm(str, Enum): """Available rate limiting algorithms.""" TOKEN_BUCKET = "token_bucket" SLIDING_WINDOW = "sliding_window" FIXED_WINDOW = "fixed_window" LEAKY_BUCKET = "leaky_bucket" SLIDING_WINDOW_COUNTER = "sliding_window_counter" class BaseAlgorithm(ABC): """Base class for rate limiting algorithms.""" __slots__ = ("backend", "burst_size", "limit", "window_size") def __init__( self, limit: int, window_size: float, backend: Backend, *, burst_size: int | None = None, ) -> None: self.limit = limit self.window_size = window_size self.backend = backend self.burst_size = burst_size or limit @abstractmethod async def check(self, key: str) -> tuple[bool, RateLimitInfo]: """Check if request is allowed and update state.""" ... @abstractmethod async def reset(self, key: str) -> None: """Reset the rate limit state for a key.""" ... @abstractmethod async def get_state(self, key: str) -> RateLimitInfo | None: """Get current state without consuming a token.""" ... class TokenBucketAlgorithm(BaseAlgorithm): """Token bucket algorithm - allows bursts up to bucket capacity.""" __slots__ = ("refill_rate",) def __init__( self, limit: int, window_size: float, backend: Backend, *, burst_size: int | None = None, ) -> None: super().__init__(limit, window_size, backend, burst_size=burst_size) self.refill_rate = limit / window_size async def check(self, key: str) -> tuple[bool, RateLimitInfo]: now = time.time() state = await self.backend.get(key) if state is None: tokens = float(self.burst_size - 1) await self.backend.set( key, {"tokens": tokens, "last_update": now}, ttl=self.window_size * 2, ) return True, RateLimitInfo( limit=self.limit, remaining=int(tokens), reset_at=now + self.window_size, window_size=self.window_size, ) tokens = float(state.get("tokens", self.burst_size)) last_update = float(state.get("last_update", now)) elapsed = now - last_update tokens = min(self.burst_size, tokens + elapsed * self.refill_rate) if tokens >= 1: tokens -= 1 allowed = True retry_after = None else: allowed = False retry_after = (1 - tokens) / self.refill_rate await self.backend.set( key, {"tokens": tokens, "last_update": now}, ttl=self.window_size * 2, ) return allowed, RateLimitInfo( limit=self.limit, remaining=int(tokens), reset_at=now + (self.burst_size - tokens) / self.refill_rate, retry_after=retry_after, window_size=self.window_size, ) async def reset(self, key: str) -> None: await self.backend.delete(key) async def get_state(self, key: str) -> RateLimitInfo | None: now = time.time() state = await self.backend.get(key) if state is None: return None tokens = float(state.get("tokens", self.burst_size)) last_update = float(state.get("last_update", now)) elapsed = now - last_update tokens = min(self.burst_size, tokens + elapsed * self.refill_rate) return RateLimitInfo( limit=self.limit, remaining=int(tokens), reset_at=now + (self.burst_size - tokens) / self.refill_rate, window_size=self.window_size, ) class SlidingWindowAlgorithm(BaseAlgorithm): """Sliding window log algorithm - precise but memory intensive.""" async def check(self, key: str) -> tuple[bool, RateLimitInfo]: now = time.time() window_start = now - self.window_size state = await self.backend.get(key) timestamps: list[float] = [] if state is not None: raw_timestamps = state.get("timestamps", []) timestamps = [ float(ts) for ts in raw_timestamps if float(ts) > window_start ] if len(timestamps) < self.limit: timestamps.append(now) allowed = True retry_after = None else: allowed = False oldest = min(timestamps) if timestamps else now retry_after = oldest + self.window_size - now await self.backend.set( key, {"timestamps": timestamps}, ttl=self.window_size * 2, ) remaining = max(0, self.limit - len(timestamps)) reset_at = (min(timestamps) if timestamps else now) + self.window_size return allowed, RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, retry_after=retry_after, window_size=self.window_size, ) async def reset(self, key: str) -> None: await self.backend.delete(key) async def get_state(self, key: str) -> RateLimitInfo | None: now = time.time() window_start = now - self.window_size state = await self.backend.get(key) if state is None: return None raw_timestamps = state.get("timestamps", []) timestamps = [float(ts) for ts in raw_timestamps if float(ts) > window_start] remaining = max(0, self.limit - len(timestamps)) reset_at = (min(timestamps) if timestamps else now) + self.window_size return RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, window_size=self.window_size, ) class FixedWindowAlgorithm(BaseAlgorithm): """Fixed window algorithm - simple and efficient.""" async def check(self, key: str) -> tuple[bool, RateLimitInfo]: now = time.time() window_start = (now // self.window_size) * self.window_size window_end = window_start + self.window_size state = await self.backend.get(key) count = 0 if state is not None: stored_window = float(state.get("window_start", 0)) if stored_window == window_start: count = int(state.get("count", 0)) if count < self.limit: count += 1 allowed = True retry_after = None else: allowed = False retry_after = window_end - now await self.backend.set( key, {"count": count, "window_start": window_start}, ttl=self.window_size * 2, ) return allowed, RateLimitInfo( limit=self.limit, remaining=max(0, self.limit - count), reset_at=window_end, retry_after=retry_after, window_size=self.window_size, ) async def reset(self, key: str) -> None: await self.backend.delete(key) async def get_state(self, key: str) -> RateLimitInfo | None: now = time.time() window_start = (now // self.window_size) * self.window_size window_end = window_start + self.window_size state = await self.backend.get(key) if state is None: return None count = 0 stored_window = float(state.get("window_start", 0)) if stored_window == window_start: count = int(state.get("count", 0)) return RateLimitInfo( limit=self.limit, remaining=max(0, self.limit - count), reset_at=window_end, window_size=self.window_size, ) class LeakyBucketAlgorithm(BaseAlgorithm): """Leaky bucket algorithm - smooths out bursts.""" __slots__ = ("leak_rate",) def __init__( self, limit: int, window_size: float, backend: Backend, *, burst_size: int | None = None, ) -> None: super().__init__(limit, window_size, backend, burst_size=burst_size) self.leak_rate = limit / window_size async def check(self, key: str) -> tuple[bool, RateLimitInfo]: now = time.time() state = await self.backend.get(key) water_level = 0.0 if state is not None: water_level = float(state.get("water_level", 0)) last_update = float(state.get("last_update", now)) elapsed = now - last_update water_level = max(0, water_level - elapsed * self.leak_rate) if water_level < self.burst_size: water_level += 1 allowed = True retry_after = None else: allowed = False retry_after = (water_level - self.burst_size + 1) / self.leak_rate await self.backend.set( key, {"water_level": water_level, "last_update": now}, ttl=self.window_size * 2, ) remaining = max(0, int(self.burst_size - water_level)) reset_at = now + water_level / self.leak_rate return allowed, RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, retry_after=retry_after, window_size=self.window_size, ) async def reset(self, key: str) -> None: await self.backend.delete(key) async def get_state(self, key: str) -> RateLimitInfo | None: now = time.time() state = await self.backend.get(key) if state is None: return None water_level = float(state.get("water_level", 0)) last_update = float(state.get("last_update", now)) elapsed = now - last_update water_level = max(0, water_level - elapsed * self.leak_rate) remaining = max(0, int(self.burst_size - water_level)) reset_at = now + water_level / self.leak_rate return RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, window_size=self.window_size, ) class SlidingWindowCounterAlgorithm(BaseAlgorithm): """Sliding window counter - balance between precision and memory.""" async def check(self, key: str) -> tuple[bool, RateLimitInfo]: now = time.time() current_window = (now // self.window_size) * self.window_size previous_window = current_window - self.window_size window_progress = (now - current_window) / self.window_size state = await self.backend.get(key) prev_count = 0 curr_count = 0 if state is not None: prev_count = int(state.get("prev_count", 0)) curr_count = int(state.get("curr_count", 0)) stored_window = float(state.get("current_window", 0)) if stored_window < previous_window: prev_count = 0 curr_count = 0 elif stored_window == previous_window: prev_count = curr_count curr_count = 0 weighted_count = prev_count * (1 - window_progress) + curr_count if weighted_count < self.limit: curr_count += 1 allowed = True retry_after = None else: allowed = False retry_after = self.window_size * (1 - window_progress) await self.backend.set( key, { "prev_count": prev_count, "curr_count": curr_count, "current_window": current_window, }, ttl=self.window_size * 3, ) remaining = max(0, int(self.limit - weighted_count)) reset_at = current_window + self.window_size return allowed, RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, retry_after=retry_after, window_size=self.window_size, ) async def reset(self, key: str) -> None: await self.backend.delete(key) async def get_state(self, key: str) -> RateLimitInfo | None: now = time.time() current_window = (now // self.window_size) * self.window_size previous_window = current_window - self.window_size window_progress = (now - current_window) / self.window_size state = await self.backend.get(key) if state is None: return None prev_count = int(state.get("prev_count", 0)) curr_count = int(state.get("curr_count", 0)) stored_window = float(state.get("current_window", 0)) if stored_window < previous_window: prev_count = 0 curr_count = 0 elif stored_window == previous_window: prev_count = curr_count curr_count = 0 weighted_count = prev_count * (1 - window_progress) + curr_count remaining = max(0, int(self.limit - weighted_count)) reset_at = current_window + self.window_size return RateLimitInfo( limit=self.limit, remaining=remaining, reset_at=reset_at, window_size=self.window_size, ) def get_algorithm( algorithm: Algorithm, limit: int, window_size: float, backend: Backend, *, burst_size: int | None = None, ) -> BaseAlgorithm: """Factory function to create algorithm instances.""" match algorithm: case Algorithm.TOKEN_BUCKET: return TokenBucketAlgorithm( limit, window_size, backend, burst_size=burst_size ) case Algorithm.SLIDING_WINDOW: return SlidingWindowAlgorithm( limit, window_size, backend, burst_size=burst_size ) case Algorithm.FIXED_WINDOW: return FixedWindowAlgorithm( limit, window_size, backend, burst_size=burst_size ) case Algorithm.LEAKY_BUCKET: return LeakyBucketAlgorithm( limit, window_size, backend, burst_size=burst_size ) case Algorithm.SLIDING_WINDOW_COUNTER: return SlidingWindowCounterAlgorithm( limit, window_size, backend, burst_size=burst_size )