style: apply ruff formatting and move TYPE_CHECKING imports in tests
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
@@ -13,8 +13,8 @@ from httpx import ASGITransport, AsyncClient
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
RateLimitExceeded,
|
||||
SQLiteBackend,
|
||||
rate_limit,
|
||||
)
|
||||
@@ -23,6 +23,7 @@ from fastapi_traffic.core.limiter import set_limiter
|
||||
from fastapi_traffic.middleware import RateLimitMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ Comprehensive tests covering:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -28,6 +27,9 @@ from fastapi_traffic.core.algorithms import (
|
||||
get_algorithm,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def backend() -> AsyncGenerator[MemoryBackend, None]:
|
||||
@@ -41,9 +43,7 @@ async def backend() -> AsyncGenerator[MemoryBackend, None]:
|
||||
class TestTokenBucketAlgorithm:
|
||||
"""Tests for TokenBucketAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = TokenBucketAlgorithm(10, 60.0, backend)
|
||||
|
||||
@@ -51,9 +51,7 @@ class TestTokenBucketAlgorithm:
|
||||
allowed, _ = await algo.check(f"key_{i % 2}")
|
||||
assert allowed, f"Request {i} should be allowed"
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = TokenBucketAlgorithm(3, 60.0, backend)
|
||||
|
||||
@@ -86,9 +84,7 @@ class TestTokenBucketAlgorithm:
|
||||
class TestSlidingWindowAlgorithm:
|
||||
"""Tests for SlidingWindowAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = SlidingWindowAlgorithm(5, 60.0, backend)
|
||||
|
||||
@@ -96,9 +92,7 @@ class TestSlidingWindowAlgorithm:
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = SlidingWindowAlgorithm(3, 60.0, backend)
|
||||
|
||||
@@ -115,9 +109,7 @@ class TestSlidingWindowAlgorithm:
|
||||
class TestFixedWindowAlgorithm:
|
||||
"""Tests for FixedWindowAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = FixedWindowAlgorithm(5, 60.0, backend)
|
||||
|
||||
@@ -125,9 +117,7 @@ class TestFixedWindowAlgorithm:
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = FixedWindowAlgorithm(3, 60.0, backend)
|
||||
|
||||
@@ -144,9 +134,7 @@ class TestFixedWindowAlgorithm:
|
||||
class TestLeakyBucketAlgorithm:
|
||||
"""Tests for LeakyBucketAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = LeakyBucketAlgorithm(5, 60.0, backend)
|
||||
|
||||
@@ -154,9 +142,7 @@ class TestLeakyBucketAlgorithm:
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = LeakyBucketAlgorithm(3, 60.0, backend)
|
||||
|
||||
@@ -176,9 +162,7 @@ class TestLeakyBucketAlgorithm:
|
||||
class TestSlidingWindowCounterAlgorithm:
|
||||
"""Tests for SlidingWindowCounterAlgorithm."""
|
||||
|
||||
async def test_allows_requests_within_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests within limit are allowed."""
|
||||
algo = SlidingWindowCounterAlgorithm(5, 60.0, backend)
|
||||
|
||||
@@ -186,9 +170,7 @@ class TestSlidingWindowCounterAlgorithm:
|
||||
allowed, _ = await algo.check("test_key")
|
||||
assert allowed
|
||||
|
||||
async def test_blocks_requests_over_limit(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
|
||||
"""Test that requests over limit are blocked."""
|
||||
algo = SlidingWindowCounterAlgorithm(3, 60.0, backend)
|
||||
|
||||
@@ -224,9 +206,7 @@ class TestGetAlgorithm:
|
||||
algo = get_algorithm(Algorithm.LEAKY_BUCKET, 10, 60.0, backend)
|
||||
assert isinstance(algo, LeakyBucketAlgorithm)
|
||||
|
||||
async def test_get_sliding_window_counter(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_get_sliding_window_counter(self, backend: MemoryBackend) -> None:
|
||||
"""Test getting sliding window counter algorithm."""
|
||||
algo = get_algorithm(Algorithm.SLIDING_WINDOW_COUNTER, 10, 60.0, backend)
|
||||
assert isinstance(algo, SlidingWindowCounterAlgorithm)
|
||||
@@ -477,9 +457,7 @@ class TestAlgorithmStateManagement:
|
||||
state = await algo.get_state("nonexistent_key")
|
||||
assert state is None
|
||||
|
||||
async def test_reset_restores_full_capacity(
|
||||
self, backend: MemoryBackend
|
||||
) -> None:
|
||||
async def test_reset_restores_full_capacity(self, backend: MemoryBackend) -> None:
|
||||
"""Test that reset restores full capacity."""
|
||||
algo = TokenBucketAlgorithm(5, 60.0, backend)
|
||||
|
||||
|
||||
@@ -13,13 +13,16 @@ Comprehensive tests covering:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_traffic.backends.memory import MemoryBackend
|
||||
from fastapi_traffic.backends.sqlite import SQLiteBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMemoryBackend:
|
||||
@@ -163,6 +166,7 @@ class TestMemoryBackendAdvanced:
|
||||
"""Test concurrent write operations don't corrupt data."""
|
||||
backend = MemoryBackend(max_size=1000)
|
||||
try:
|
||||
|
||||
async def write_key(i: int) -> None:
|
||||
await backend.set(f"key_{i}", {"value": i}, ttl=60.0)
|
||||
|
||||
@@ -302,6 +306,7 @@ class TestSQLiteBackendAdvanced:
|
||||
backend = SQLiteBackend(":memory:")
|
||||
await backend.initialize()
|
||||
try:
|
||||
|
||||
async def write_key(i: int) -> None:
|
||||
await backend.set(f"key_{i}", {"value": i}, ttl=60.0)
|
||||
|
||||
@@ -377,7 +382,9 @@ class TestBackendInterface:
|
||||
"""Tests to verify backend interface consistency."""
|
||||
|
||||
@pytest.fixture
|
||||
async def backends(self) -> AsyncGenerator[list[MemoryBackend | SQLiteBackend], None]:
|
||||
async def backends(
|
||||
self,
|
||||
) -> AsyncGenerator[list[MemoryBackend | SQLiteBackend], None]:
|
||||
"""Create all backend types for testing."""
|
||||
memory = MemoryBackend()
|
||||
sqlite = SQLiteBackend(":memory:")
|
||||
|
||||
@@ -14,7 +14,7 @@ Comprehensive tests covering:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
@@ -23,12 +23,15 @@ from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from fastapi_traffic import (
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
RateLimitExceeded,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
class TestRateLimitDecorator:
|
||||
"""Tests for the @rate_limit decorator."""
|
||||
@@ -175,9 +178,7 @@ class TestCustomKeyExtractor:
|
||||
) -> None:
|
||||
"""Test that different API keys have separate rate limits."""
|
||||
for _ in range(2):
|
||||
response = await client.get(
|
||||
"/by-api-key", headers={"X-API-Key": "key-a"}
|
||||
)
|
||||
response = await client.get("/by-api-key", headers={"X-API-Key": "key-a"})
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get("/by-api-key", headers={"X-API-Key": "key-a"})
|
||||
@@ -186,9 +187,7 @@ class TestCustomKeyExtractor:
|
||||
response = await client.get("/by-api-key", headers={"X-API-Key": "key-b"})
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_anonymous_key_for_missing_header(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_anonymous_key_for_missing_header(self, client: AsyncClient) -> None:
|
||||
"""Test that missing API key uses anonymous."""
|
||||
for _ in range(2):
|
||||
response = await client.get("/by-api-key")
|
||||
@@ -255,8 +254,6 @@ class TestExemptionCallback:
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
|
||||
|
||||
class TestCostParameter:
|
||||
"""Tests for the cost parameter."""
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
@@ -18,14 +18,17 @@ from httpx import ASGITransport, AsyncClient
|
||||
from fastapi_traffic import (
|
||||
Algorithm,
|
||||
MemoryBackend,
|
||||
RateLimitExceeded,
|
||||
RateLimiter,
|
||||
RateLimitExceeded,
|
||||
rate_limit,
|
||||
)
|
||||
from fastapi_traffic.core.config import RateLimitConfig
|
||||
from fastapi_traffic.core.limiter import set_limiter
|
||||
from fastapi_traffic.middleware import RateLimitMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
class TestFullApplicationFlow:
|
||||
"""Integration tests for a complete application setup."""
|
||||
@@ -128,9 +131,7 @@ class TestFullApplicationFlow:
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_basic_rate_limiting_works(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None:
|
||||
"""Test that basic rate limiting is functional."""
|
||||
# Make a request and verify it works
|
||||
response = await client.get("/api/v1/users/1")
|
||||
|
||||
@@ -13,7 +13,7 @@ Comprehensive tests covering:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
@@ -26,6 +26,9 @@ from fastapi_traffic.middleware import (
|
||||
TokenBucketMiddleware,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Tests for RateLimitMiddleware."""
|
||||
@@ -81,7 +84,9 @@ class TestRateLimitMiddleware:
|
||||
assert "X-RateLimit-Remaining" in response.headers
|
||||
assert "X-RateLimit-Reset" in response.headers
|
||||
|
||||
async def test_different_endpoints_counted_separately(self, client: AsyncClient) -> None:
|
||||
async def test_different_endpoints_counted_separately(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Test that different endpoints are counted separately by path."""
|
||||
# Middleware includes path in the key by default
|
||||
for _ in range(3):
|
||||
@@ -224,14 +229,10 @@ class TestMiddlewareCustomKeyExtractor:
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get(
|
||||
"/api/resource", headers={"X-User-ID": "user-1"}
|
||||
)
|
||||
response = await client.get("/api/resource", headers={"X-User-ID": "user-1"})
|
||||
assert response.status_code == 429
|
||||
|
||||
response = await client.get(
|
||||
"/api/resource", headers={"X-User-ID": "user-2"}
|
||||
)
|
||||
response = await client.get("/api/resource", headers={"X-User-ID": "user-2"})
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -313,7 +314,9 @@ class TestMiddlewareErrorHandling:
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, app_skip_on_error: FastAPI) -> AsyncGenerator[AsyncClient, None]:
|
||||
async def client(
|
||||
self, app_skip_on_error: FastAPI
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test client."""
|
||||
transport = ASGITransport(app=app_skip_on_error)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
|
||||
Reference in New Issue
Block a user