(Feat): Initial Commit
This commit is contained in:
374
discord_bot/sdk/client.py
Normal file
374
discord_bot/sdk/client.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Async client for the Summarizer API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import TracebackType
|
||||
from typing import Any, Self
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import (
|
||||
SummarizerAPIError,
|
||||
SummarizerBadRequestError,
|
||||
SummarizerConnectionError,
|
||||
SummarizerRateLimitError,
|
||||
SummarizerServerError,
|
||||
SummarizerTimeoutError,
|
||||
)
|
||||
from .models import (
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
DeleteCacheRequest,
|
||||
ErrorResponse,
|
||||
HealthResponse,
|
||||
LlmInfo,
|
||||
LlmProvider,
|
||||
OcrSummarizeResponse,
|
||||
SuccessResponse,
|
||||
SummarizationStyle,
|
||||
SummarizeRequest,
|
||||
SummarizeResponse,
|
||||
)
|
||||
|
||||
DEFAULT_BASE_URL = "http://127.0.0.1:3001/api"
|
||||
DEFAULT_TIMEOUT = 60.0
|
||||
|
||||
|
||||
class SummarizerClient:
|
||||
"""Async client for the Summarizer API.
|
||||
|
||||
Usage:
|
||||
async with SummarizerClient() as client:
|
||||
response = await client.summarize(
|
||||
text="Long text to summarize...",
|
||||
provider=LlmProvider.ANTHROPIC,
|
||||
guild_id="123456789",
|
||||
user_id="987654321",
|
||||
)
|
||||
print(response.summary)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = DEFAULT_BASE_URL,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL for the API (default: http://127.0.0.1:3001/api)
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
headers: Additional headers to include in requests
|
||||
"""
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._headers = headers or {}
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Enter async context manager and create HTTP session."""
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._base_url,
|
||||
timeout=self._timeout,
|
||||
headers=self._headers,
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit async context manager and close HTTP session."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def _ensure_client(self) -> httpx.AsyncClient:
|
||||
"""Ensure client is initialized."""
|
||||
if self._client is None:
|
||||
msg = (
|
||||
"Client not initialized. Use 'async with SummarizerClient() as client:'"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
return self._client
|
||||
|
||||
async def _handle_response(self, response: httpx.Response) -> httpx.Response:
|
||||
"""Handle HTTP response and raise appropriate exceptions."""
|
||||
if response.status_code == 200:
|
||||
return response
|
||||
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_response = ErrorResponse.model_validate(error_data)
|
||||
error_message = error_response.error
|
||||
except Exception: # noqa: BLE001
|
||||
error_message = response.text or f"HTTP {response.status_code}"
|
||||
|
||||
if response.status_code == 400:
|
||||
raise SummarizerBadRequestError(error_message, response)
|
||||
if response.status_code == 429:
|
||||
raise SummarizerRateLimitError(error_message)
|
||||
if response.status_code >= 500:
|
||||
raise SummarizerServerError(error_message, response.status_code)
|
||||
|
||||
raise SummarizerAPIError(error_message, response.status_code)
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json: dict[str, Any] | None = None,
|
||||
params: dict[str, str | int | float] | None = None,
|
||||
content: bytes | None = None,
|
||||
content_type: str | None = None,
|
||||
) -> httpx.Response:
|
||||
"""Make an HTTP request."""
|
||||
client = self._ensure_client()
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
try:
|
||||
response = await client.request(
|
||||
method,
|
||||
path,
|
||||
json=json,
|
||||
params=params,
|
||||
content=content,
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
return await self._handle_response(response)
|
||||
except httpx.ConnectError as e:
|
||||
raise SummarizerConnectionError(str(e)) from e
|
||||
except httpx.TimeoutException as e:
|
||||
raise SummarizerTimeoutError(str(e)) from e
|
||||
except httpx.HTTPError as e:
|
||||
raise SummarizerAPIError(str(e)) from e
|
||||
|
||||
async def summarize(
|
||||
self,
|
||||
text: str,
|
||||
provider: LlmProvider,
|
||||
guild_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
style: SummarizationStyle | None = None,
|
||||
system_prompt: str | None = None,
|
||||
channel_id: str | None = None,
|
||||
) -> SummarizeResponse:
|
||||
"""Summarize text using the specified LLM provider.
|
||||
|
||||
Args:
|
||||
text: The text to summarize
|
||||
provider: LLM provider (anthropic, ollama, or lmstudio)
|
||||
guild_id: Discord guild ID
|
||||
user_id: Discord user ID
|
||||
model: Model name (optional, uses default from config)
|
||||
temperature: Temperature (0.0 to 2.0)
|
||||
max_tokens: Max tokens for the response
|
||||
top_p: Top P sampling (0.0 to 1.0)
|
||||
style: Predefined summarization style
|
||||
system_prompt: Custom system prompt (ignored if style is set)
|
||||
channel_id: Discord channel ID
|
||||
|
||||
Returns:
|
||||
SummarizeResponse with the generated summary
|
||||
"""
|
||||
request = SummarizeRequest(
|
||||
text=text,
|
||||
provider=provider,
|
||||
guild_id=guild_id,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
style=style,
|
||||
system_prompt=system_prompt,
|
||||
channel_id=channel_id,
|
||||
)
|
||||
|
||||
response = await self._request(
|
||||
"POST",
|
||||
"/summarize",
|
||||
json=request.model_dump(exclude_none=True),
|
||||
)
|
||||
return SummarizeResponse.model_validate(response.json())
|
||||
|
||||
async def ocr_summarize(
|
||||
self,
|
||||
image_data: bytes,
|
||||
provider: LlmProvider,
|
||||
guild_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
style: SummarizationStyle | None = None,
|
||||
channel_id: str | None = None,
|
||||
) -> OcrSummarizeResponse:
|
||||
"""OCR and summarize an image.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
provider: LLM provider
|
||||
guild_id: Discord guild ID
|
||||
user_id: Discord user ID
|
||||
model: Model name (optional)
|
||||
temperature: Temperature (0.0 to 2.0)
|
||||
max_tokens: Max tokens for the response
|
||||
top_p: Top P sampling (0.0 to 1.0)
|
||||
style: Predefined summarization style
|
||||
channel_id: Discord channel ID
|
||||
|
||||
Returns:
|
||||
OcrSummarizeResponse with extracted text and summary
|
||||
"""
|
||||
params: dict[str, str | int | float] = {
|
||||
"provider": provider.value,
|
||||
"guild_id": guild_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if model is not None:
|
||||
params["model"] = model
|
||||
if temperature is not None:
|
||||
params["temperature"] = temperature
|
||||
if max_tokens is not None:
|
||||
params["max_tokens"] = max_tokens
|
||||
if top_p is not None:
|
||||
params["top_p"] = top_p
|
||||
if style is not None:
|
||||
params["style"] = style.value
|
||||
if channel_id is not None:
|
||||
params["channel_id"] = channel_id
|
||||
|
||||
response = await self._request(
|
||||
"POST",
|
||||
"/ocr-summarize",
|
||||
params=params,
|
||||
content=image_data,
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
return OcrSummarizeResponse.model_validate(response.json())
|
||||
|
||||
async def list_models(self) -> list[LlmInfo]:
|
||||
"""List available LLM models.
|
||||
|
||||
Returns:
|
||||
List of available LLM models
|
||||
"""
|
||||
response = await self._request("GET", "/models")
|
||||
data = response.json()
|
||||
return [LlmInfo.model_validate(item) for item in data]
|
||||
|
||||
async def get_cache_stats(self) -> CacheStats:
|
||||
"""Get cache statistics.
|
||||
|
||||
Returns:
|
||||
CacheStats with cache metrics
|
||||
"""
|
||||
response = await self._request("GET", "/cache/stats")
|
||||
return CacheStats.model_validate(response.json())
|
||||
|
||||
async def list_cache_entries(
|
||||
self,
|
||||
*,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[CacheEntry]:
|
||||
"""List cache entries with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entries to return
|
||||
offset: Number of entries to skip
|
||||
|
||||
Returns:
|
||||
List of cache entries
|
||||
"""
|
||||
params: dict[str, str | int | float] = {}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
|
||||
response = await self._request("GET", "/cache/entries", params=params)
|
||||
data = response.json()
|
||||
return [CacheEntry.model_validate(item) for item in data]
|
||||
|
||||
async def get_guild_cache(self, guild_id: str) -> list[CacheEntry]:
|
||||
"""Get cache entries for a specific guild.
|
||||
|
||||
Args:
|
||||
guild_id: Discord guild ID
|
||||
|
||||
Returns:
|
||||
List of cache entries for the guild
|
||||
"""
|
||||
response = await self._request("GET", f"/cache/guild/{guild_id}")
|
||||
data = response.json()
|
||||
return [CacheEntry.model_validate(item) for item in data]
|
||||
|
||||
async def get_user_cache(self, user_id: str) -> list[CacheEntry]:
|
||||
"""Get cache entries for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
List of cache entries for the user
|
||||
"""
|
||||
response = await self._request("GET", f"/cache/user/{user_id}")
|
||||
data = response.json()
|
||||
return [CacheEntry.model_validate(item) for item in data]
|
||||
|
||||
async def delete_cache(
|
||||
self,
|
||||
*,
|
||||
entry_id: int | None = None,
|
||||
guild_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
delete_all: bool = False,
|
||||
) -> SuccessResponse:
|
||||
"""Delete cache entries.
|
||||
|
||||
Args:
|
||||
entry_id: Specific cache entry ID to delete
|
||||
guild_id: Delete all entries for a guild
|
||||
user_id: Delete all entries for a user
|
||||
delete_all: Delete all entries (use with caution)
|
||||
|
||||
Returns:
|
||||
SuccessResponse confirming deletion
|
||||
"""
|
||||
request = DeleteCacheRequest(
|
||||
id=entry_id,
|
||||
guild_id=guild_id,
|
||||
user_id=user_id,
|
||||
delete_all=delete_all if delete_all else None,
|
||||
)
|
||||
|
||||
response = await self._request(
|
||||
"POST",
|
||||
"/cache/delete",
|
||||
json=request.model_dump(exclude_none=True),
|
||||
)
|
||||
return SuccessResponse.model_validate(response.json())
|
||||
|
||||
async def health_check(self) -> HealthResponse:
|
||||
"""Check API health.
|
||||
|
||||
Returns:
|
||||
HealthResponse with status and version
|
||||
"""
|
||||
response = await self._request("GET", "/health")
|
||||
return HealthResponse.model_validate(response.json())
|
||||
Reference in New Issue
Block a user