diff --git a/tests/conftest.py b/tests/conftest.py index 218c125..7d05c09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, AsyncGenerator, Generator +from typing import TYPE_CHECKING import pytest from fastapi import FastAPI, Request @@ -13,8 +13,8 @@ from httpx import ASGITransport, AsyncClient from fastapi_traffic import ( Algorithm, MemoryBackend, - RateLimitExceeded, RateLimiter, + RateLimitExceeded, SQLiteBackend, rate_limit, ) @@ -23,6 +23,7 @@ from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.middleware import RateLimitMiddleware if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator pass diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index f50c323..c7e9f09 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -12,8 +12,7 @@ Comprehensive tests covering: from __future__ import annotations import asyncio -import time -from typing import AsyncGenerator +from typing import TYPE_CHECKING import pytest @@ -28,6 +27,9 @@ from fastapi_traffic.core.algorithms import ( get_algorithm, ) +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + @pytest.fixture async def backend() -> AsyncGenerator[MemoryBackend, None]: @@ -41,9 +43,7 @@ async def backend() -> AsyncGenerator[MemoryBackend, None]: class TestTokenBucketAlgorithm: """Tests for TokenBucketAlgorithm.""" - async def test_allows_requests_within_limit( - self, backend: MemoryBackend - ) -> None: + async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = TokenBucketAlgorithm(10, 60.0, backend) @@ -51,9 +51,7 @@ class TestTokenBucketAlgorithm: allowed, _ = await algo.check(f"key_{i % 2}") assert allowed, f"Request {i} should be allowed" - async def test_blocks_requests_over_limit( - self, backend: MemoryBackend - ) -> None: + async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = TokenBucketAlgorithm(3, 60.0, backend) @@ -86,9 +84,7 @@ class TestTokenBucketAlgorithm: class TestSlidingWindowAlgorithm: """Tests for SlidingWindowAlgorithm.""" - async def test_allows_requests_within_limit( - self, backend: MemoryBackend - ) -> None: + async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = SlidingWindowAlgorithm(5, 60.0, backend) @@ -96,9 +92,7 @@ class TestSlidingWindowAlgorithm: allowed, _ = await algo.check("test_key") assert allowed - async def test_blocks_requests_over_limit( - self, backend: MemoryBackend - ) -> None: + async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = SlidingWindowAlgorithm(3, 60.0, backend) @@ -115,9 +109,7 @@ class TestSlidingWindowAlgorithm: class TestFixedWindowAlgorithm: """Tests for FixedWindowAlgorithm.""" - async def test_allows_requests_within_limit( - self, backend: MemoryBackend - ) -> None: + async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = FixedWindowAlgorithm(5, 60.0, backend) @@ -125,9 +117,7 @@ class TestFixedWindowAlgorithm: allowed, _ = await algo.check("test_key") assert allowed - async def test_blocks_requests_over_limit( - self, backend: MemoryBackend - ) -> None: + async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = FixedWindowAlgorithm(3, 60.0, backend) @@ -144,9 +134,7 @@ class TestFixedWindowAlgorithm: class TestLeakyBucketAlgorithm: """Tests for LeakyBucketAlgorithm.""" - async def test_allows_requests_within_limit( - self, backend: MemoryBackend - ) -> None: + async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = LeakyBucketAlgorithm(5, 60.0, backend) @@ -154,9 +142,7 @@ class TestLeakyBucketAlgorithm: allowed, _ = await algo.check("test_key") assert allowed - async def test_blocks_requests_over_limit( - self, backend: MemoryBackend - ) -> None: + async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = LeakyBucketAlgorithm(3, 60.0, backend) @@ -176,9 +162,7 @@ class TestLeakyBucketAlgorithm: class TestSlidingWindowCounterAlgorithm: """Tests for SlidingWindowCounterAlgorithm.""" - async def test_allows_requests_within_limit( - self, backend: MemoryBackend - ) -> None: + async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None: """Test that requests within limit are allowed.""" algo = SlidingWindowCounterAlgorithm(5, 60.0, backend) @@ -186,9 +170,7 @@ class TestSlidingWindowCounterAlgorithm: allowed, _ = await algo.check("test_key") assert allowed - async def test_blocks_requests_over_limit( - self, backend: MemoryBackend - ) -> None: + async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None: """Test that requests over limit are blocked.""" algo = SlidingWindowCounterAlgorithm(3, 60.0, backend) @@ -224,9 +206,7 @@ class TestGetAlgorithm: algo = get_algorithm(Algorithm.LEAKY_BUCKET, 10, 60.0, backend) assert isinstance(algo, LeakyBucketAlgorithm) - async def test_get_sliding_window_counter( - self, backend: MemoryBackend - ) -> None: + async def test_get_sliding_window_counter(self, backend: MemoryBackend) -> None: """Test getting sliding window counter algorithm.""" algo = get_algorithm(Algorithm.SLIDING_WINDOW_COUNTER, 10, 60.0, backend) assert isinstance(algo, SlidingWindowCounterAlgorithm) @@ -477,9 +457,7 @@ class TestAlgorithmStateManagement: state = await algo.get_state("nonexistent_key") assert state is None - async def test_reset_restores_full_capacity( - self, backend: MemoryBackend - ) -> None: + async def test_reset_restores_full_capacity(self, backend: MemoryBackend) -> None: """Test that reset restores full capacity.""" algo = TokenBucketAlgorithm(5, 60.0, backend) diff --git a/tests/test_backends.py b/tests/test_backends.py index a461e55..dd71006 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -13,13 +13,16 @@ Comprehensive tests covering: from __future__ import annotations import asyncio -from typing import AsyncGenerator +from typing import TYPE_CHECKING import pytest from fastapi_traffic.backends.memory import MemoryBackend from fastapi_traffic.backends.sqlite import SQLiteBackend +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + @pytest.mark.asyncio class TestMemoryBackend: @@ -163,6 +166,7 @@ class TestMemoryBackendAdvanced: """Test concurrent write operations don't corrupt data.""" backend = MemoryBackend(max_size=1000) try: + async def write_key(i: int) -> None: await backend.set(f"key_{i}", {"value": i}, ttl=60.0) @@ -302,6 +306,7 @@ class TestSQLiteBackendAdvanced: backend = SQLiteBackend(":memory:") await backend.initialize() try: + async def write_key(i: int) -> None: await backend.set(f"key_{i}", {"value": i}, ttl=60.0) @@ -377,7 +382,9 @@ class TestBackendInterface: """Tests to verify backend interface consistency.""" @pytest.fixture - async def backends(self) -> AsyncGenerator[list[MemoryBackend | SQLiteBackend], None]: + async def backends( + self, + ) -> AsyncGenerator[list[MemoryBackend | SQLiteBackend], None]: """Create all backend types for testing.""" memory = MemoryBackend() sqlite = SQLiteBackend(":memory:") diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 04909ac..aedbd06 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -14,7 +14,7 @@ Comprehensive tests covering: from __future__ import annotations -from typing import AsyncGenerator +from typing import TYPE_CHECKING import pytest from fastapi import FastAPI, Request @@ -23,12 +23,15 @@ from httpx import ASGITransport, AsyncClient from fastapi_traffic import ( MemoryBackend, - RateLimitExceeded, RateLimiter, + RateLimitExceeded, rate_limit, ) from fastapi_traffic.core.limiter import set_limiter +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + class TestRateLimitDecorator: """Tests for the @rate_limit decorator.""" @@ -175,9 +178,7 @@ class TestCustomKeyExtractor: ) -> None: """Test that different API keys have separate rate limits.""" for _ in range(2): - response = await client.get( - "/by-api-key", headers={"X-API-Key": "key-a"} - ) + response = await client.get("/by-api-key", headers={"X-API-Key": "key-a"}) assert response.status_code == 200 response = await client.get("/by-api-key", headers={"X-API-Key": "key-a"}) @@ -186,9 +187,7 @@ class TestCustomKeyExtractor: response = await client.get("/by-api-key", headers={"X-API-Key": "key-b"}) assert response.status_code == 200 - async def test_anonymous_key_for_missing_header( - self, client: AsyncClient - ) -> None: + async def test_anonymous_key_for_missing_header(self, client: AsyncClient) -> None: """Test that missing API key uses anonymous.""" for _ in range(2): response = await client.get("/by-api-key") @@ -255,8 +254,6 @@ class TestExemptionCallback: assert response.status_code == 429 - - class TestCostParameter: """Tests for the cost parameter.""" diff --git a/tests/test_integration.py b/tests/test_integration.py index 0803b6f..d53cf68 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -8,7 +8,7 @@ from __future__ import annotations import asyncio from contextlib import asynccontextmanager -from typing import AsyncGenerator +from typing import TYPE_CHECKING import pytest from fastapi import FastAPI, Request @@ -18,14 +18,17 @@ from httpx import ASGITransport, AsyncClient from fastapi_traffic import ( Algorithm, MemoryBackend, - RateLimitExceeded, RateLimiter, + RateLimitExceeded, rate_limit, ) from fastapi_traffic.core.config import RateLimitConfig from fastapi_traffic.core.limiter import set_limiter from fastapi_traffic.middleware import RateLimitMiddleware +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + class TestFullApplicationFlow: """Integration tests for a complete application setup.""" @@ -128,9 +131,7 @@ class TestFullApplicationFlow: ) assert response.status_code == 200 - async def test_basic_rate_limiting_works( - self, client: AsyncClient - ) -> None: + async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None: """Test that basic rate limiting is functional.""" # Make a request and verify it works response = await client.get("/api/v1/users/1") diff --git a/tests/test_middleware.py b/tests/test_middleware.py index ca190b7..34e253d 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -13,7 +13,7 @@ Comprehensive tests covering: from __future__ import annotations -from typing import AsyncGenerator +from typing import TYPE_CHECKING import pytest from fastapi import FastAPI @@ -26,6 +26,9 @@ from fastapi_traffic.middleware import ( TokenBucketMiddleware, ) +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + class TestRateLimitMiddleware: """Tests for RateLimitMiddleware.""" @@ -81,7 +84,9 @@ class TestRateLimitMiddleware: assert "X-RateLimit-Remaining" in response.headers assert "X-RateLimit-Reset" in response.headers - async def test_different_endpoints_counted_separately(self, client: AsyncClient) -> None: + async def test_different_endpoints_counted_separately( + self, client: AsyncClient + ) -> None: """Test that different endpoints are counted separately by path.""" # Middleware includes path in the key by default for _ in range(3): @@ -224,14 +229,10 @@ class TestMiddlewareCustomKeyExtractor: ) assert response.status_code == 200 - response = await client.get( - "/api/resource", headers={"X-User-ID": "user-1"} - ) + response = await client.get("/api/resource", headers={"X-User-ID": "user-1"}) assert response.status_code == 429 - response = await client.get( - "/api/resource", headers={"X-User-ID": "user-2"} - ) + response = await client.get("/api/resource", headers={"X-User-ID": "user-2"}) assert response.status_code == 200 @@ -313,7 +314,9 @@ class TestMiddlewareErrorHandling: return app @pytest.fixture - async def client(self, app_skip_on_error: FastAPI) -> AsyncGenerator[AsyncClient, None]: + async def client( + self, app_skip_on_error: FastAPI + ) -> AsyncGenerator[AsyncClient, None]: """Create test client.""" transport = ASGITransport(app=app_skip_on_error) async with AsyncClient(transport=transport, base_url="http://test") as client: