(Feat): Initial Commit

This commit is contained in:
2025-12-12 15:31:27 +00:00
commit 5b13236b80
48 changed files with 13146 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
"""Summarizer API Python SDK.
A fully typed async client for the Summarizer API.
Usage:
from sdk import SummarizerClient, LlmProvider
async with SummarizerClient() as client:
response = await client.summarize(
text="Long text to summarize...",
provider=LlmProvider.ANTHROPIC,
guild_id="123456789",
user_id="987654321",
)
print(response.summary)
"""
from .client import SummarizerClient
from .exceptions import (
SummarizerAPIError,
SummarizerBadRequestError,
SummarizerConnectionError,
SummarizerRateLimitError,
SummarizerServerError,
SummarizerTimeoutError,
)
from .models import (
CacheEntry,
CacheStats,
DeleteCacheRequest,
ErrorResponse,
ImageInfo,
LlmInfo,
LlmProvider,
OcrSummarizeResponse,
SuccessResponse,
SummarizationStyle,
SummarizeRequest,
SummarizeResponse,
TokenUsage,
)
__all__ = [
# Client
"SummarizerClient",
# Exceptions
"SummarizerAPIError",
"SummarizerBadRequestError",
"SummarizerConnectionError",
"SummarizerRateLimitError",
"SummarizerServerError",
"SummarizerTimeoutError",
# Enums
"LlmProvider",
"SummarizationStyle",
# Request models
"SummarizeRequest",
"DeleteCacheRequest",
# Response models
"SummarizeResponse",
"OcrSummarizeResponse",
"CacheEntry",
"CacheStats",
"LlmInfo",
"SuccessResponse",
"ErrorResponse",
"TokenUsage",
"ImageInfo",
]

374
discord_bot/sdk/client.py Normal file
View File

@@ -0,0 +1,374 @@
"""Async client for the Summarizer API."""
from __future__ import annotations
from types import TracebackType
from typing import Any, Self
import httpx
from .exceptions import (
SummarizerAPIError,
SummarizerBadRequestError,
SummarizerConnectionError,
SummarizerRateLimitError,
SummarizerServerError,
SummarizerTimeoutError,
)
from .models import (
CacheEntry,
CacheStats,
DeleteCacheRequest,
ErrorResponse,
HealthResponse,
LlmInfo,
LlmProvider,
OcrSummarizeResponse,
SuccessResponse,
SummarizationStyle,
SummarizeRequest,
SummarizeResponse,
)
# Local API root used when the caller supplies no base_url.
DEFAULT_BASE_URL = "http://127.0.0.1:3001/api"
# Per-request timeout in seconds; generous because summarization is LLM-bound.
DEFAULT_TIMEOUT = 60.0
class SummarizerClient:
    """Async client for the Summarizer API.

    Usage:
        async with SummarizerClient() as client:
            response = await client.summarize(
                text="Long text to summarize...",
                provider=LlmProvider.ANTHROPIC,
                guild_id="123456789",
                user_id="987654321",
            )
            print(response.summary)
    """

    def __init__(
        self,
        base_url: str = DEFAULT_BASE_URL,
        timeout: float = DEFAULT_TIMEOUT,
        headers: dict[str, str] | None = None,
    ) -> None:
        """Initialize the client.

        Args:
            base_url: Base URL for the API (default: http://127.0.0.1:3001/api)
            timeout: Request timeout in seconds (default: 60.0)
            headers: Additional headers to include in requests
        """
        # Strip the trailing slash so httpx path joining does not emit "//".
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._headers = headers or {}
        # Created in __aenter__; None means "not inside the context manager".
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self) -> Self:
        """Enter async context manager and create HTTP session."""
        self._client = httpx.AsyncClient(
            base_url=self._base_url,
            timeout=self._timeout,
            headers=self._headers,
        )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit async context manager and close HTTP session."""
        if self._client:
            await self._client.aclose()
        self._client = None

    def _ensure_client(self) -> httpx.AsyncClient:
        """Return the live HTTP client.

        Raises:
            RuntimeError: if called outside ``async with``.
        """
        if self._client is None:
            msg = (
                "Client not initialized. Use 'async with SummarizerClient() as client:'"
            )
            raise RuntimeError(msg)
        return self._client

    async def _handle_response(self, response: httpx.Response) -> httpx.Response:
        """Return the response on success, else raise a typed SDK exception.

        Any 2xx status counts as success (not only 200), so endpoints that
        reply 201/202/204 are not misreported as errors.

        Raises:
            SummarizerBadRequestError: for HTTP 400
            SummarizerRateLimitError: for HTTP 429
            SummarizerServerError: for any 5xx
            SummarizerAPIError: for any other non-success status
        """
        if response.is_success:
            return response
        # Prefer the API's structured error message; fall back to the raw body.
        try:
            error_data = response.json()
            error_response = ErrorResponse.model_validate(error_data)
            error_message = error_response.error
        except Exception:  # noqa: BLE001
            error_message = response.text or f"HTTP {response.status_code}"
        if response.status_code == 400:
            raise SummarizerBadRequestError(error_message, response)
        if response.status_code == 429:
            raise SummarizerRateLimitError(error_message)
        if response.status_code >= 500:
            raise SummarizerServerError(error_message, response.status_code)
        raise SummarizerAPIError(error_message, response.status_code)

    async def _request(
        self,
        method: str,
        path: str,
        *,
        json: dict[str, Any] | None = None,
        params: dict[str, str | int | float] | None = None,
        content: bytes | None = None,
        content_type: str | None = None,
    ) -> httpx.Response:
        """Make an HTTP request, translating transport errors to SDK errors.

        Raises:
            SummarizerConnectionError: when the connection cannot be opened
            SummarizerTimeoutError: when the request times out
            SummarizerAPIError: for any other httpx transport failure
        """
        client = self._ensure_client()
        headers: dict[str, str] = {}
        if content_type:
            headers["Content-Type"] = content_type
        try:
            response = await client.request(
                method,
                path,
                json=json,
                params=params,
                content=content,
                headers=headers if headers else None,
            )
            return await self._handle_response(response)
        except httpx.ConnectError as e:
            raise SummarizerConnectionError(str(e)) from e
        except httpx.TimeoutException as e:
            raise SummarizerTimeoutError(str(e)) from e
        except httpx.HTTPError as e:
            raise SummarizerAPIError(str(e)) from e

    @staticmethod
    def _parse_cache_entries(response: httpx.Response) -> list[CacheEntry]:
        """Deserialize a JSON array response into CacheEntry models."""
        return [CacheEntry.model_validate(item) for item in response.json()]

    async def summarize(
        self,
        text: str,
        provider: LlmProvider,
        guild_id: str,
        user_id: str,
        *,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        style: SummarizationStyle | None = None,
        system_prompt: str | None = None,
        channel_id: str | None = None,
    ) -> SummarizeResponse:
        """Summarize text using the specified LLM provider.

        Args:
            text: The text to summarize
            provider: LLM provider (anthropic, ollama, or lmstudio)
            guild_id: Discord guild ID
            user_id: Discord user ID
            model: Model name (optional, uses default from config)
            temperature: Temperature (0.0 to 2.0)
            max_tokens: Max tokens for the response
            top_p: Top P sampling (0.0 to 1.0)
            style: Predefined summarization style
            system_prompt: Custom system prompt (ignored if style is set)
            channel_id: Discord channel ID

        Returns:
            SummarizeResponse with the generated summary
        """
        request = SummarizeRequest(
            text=text,
            provider=provider,
            guild_id=guild_id,
            user_id=user_id,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            style=style,
            system_prompt=system_prompt,
            channel_id=channel_id,
        )
        response = await self._request(
            "POST",
            "/summarize",
            # exclude_none keeps the payload minimal; the server applies defaults.
            json=request.model_dump(exclude_none=True),
        )
        return SummarizeResponse.model_validate(response.json())

    async def ocr_summarize(
        self,
        image_data: bytes,
        provider: LlmProvider,
        guild_id: str,
        user_id: str,
        *,
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        style: SummarizationStyle | None = None,
        channel_id: str | None = None,
    ) -> OcrSummarizeResponse:
        """OCR and summarize an image.

        Args:
            image_data: Raw image bytes
            provider: LLM provider
            guild_id: Discord guild ID
            user_id: Discord user ID
            model: Model name (optional)
            temperature: Temperature (0.0 to 2.0)
            max_tokens: Max tokens for the response
            top_p: Top P sampling (0.0 to 1.0)
            style: Predefined summarization style
            channel_id: Discord channel ID

        Returns:
            OcrSummarizeResponse with extracted text and summary
        """
        params: dict[str, str | int | float] = {
            "provider": provider.value,
            "guild_id": guild_id,
            "user_id": user_id,
        }
        # Optional query parameters are only sent when explicitly provided.
        optional: dict[str, str | int | float | None] = {
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "style": style.value if style is not None else None,
            "channel_id": channel_id,
        }
        params.update({k: v for k, v in optional.items() if v is not None})
        response = await self._request(
            "POST",
            "/ocr-summarize",
            params=params,
            content=image_data,
            content_type="application/octet-stream",
        )
        return OcrSummarizeResponse.model_validate(response.json())

    async def list_models(self) -> list[LlmInfo]:
        """List available LLM models.

        Returns:
            List of available LLM models
        """
        response = await self._request("GET", "/models")
        return [LlmInfo.model_validate(item) for item in response.json()]

    async def get_cache_stats(self) -> CacheStats:
        """Get cache statistics.

        Returns:
            CacheStats with cache metrics
        """
        response = await self._request("GET", "/cache/stats")
        return CacheStats.model_validate(response.json())

    async def list_cache_entries(
        self,
        *,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[CacheEntry]:
        """List cache entries with pagination.

        Args:
            limit: Maximum number of entries to return
            offset: Number of entries to skip

        Returns:
            List of cache entries
        """
        params: dict[str, str | int | float] = {}
        if limit is not None:
            params["limit"] = limit
        if offset is not None:
            params["offset"] = offset
        response = await self._request("GET", "/cache/entries", params=params)
        return self._parse_cache_entries(response)

    async def get_guild_cache(self, guild_id: str) -> list[CacheEntry]:
        """Get cache entries for a specific guild.

        Args:
            guild_id: Discord guild ID

        Returns:
            List of cache entries for the guild
        """
        response = await self._request("GET", f"/cache/guild/{guild_id}")
        return self._parse_cache_entries(response)

    async def get_user_cache(self, user_id: str) -> list[CacheEntry]:
        """Get cache entries for a specific user.

        Args:
            user_id: Discord user ID

        Returns:
            List of cache entries for the user
        """
        response = await self._request("GET", f"/cache/user/{user_id}")
        return self._parse_cache_entries(response)

    async def delete_cache(
        self,
        *,
        entry_id: int | None = None,
        guild_id: str | None = None,
        user_id: str | None = None,
        delete_all: bool = False,
    ) -> SuccessResponse:
        """Delete cache entries.

        Args:
            entry_id: Specific cache entry ID to delete
            guild_id: Delete all entries for a guild
            user_id: Delete all entries for a user
            delete_all: Delete all entries (use with caution)

        Returns:
            SuccessResponse confirming deletion
        """
        request = DeleteCacheRequest(
            id=entry_id,
            guild_id=guild_id,
            user_id=user_id,
            # Map False to None so exclude_none drops the flag entirely.
            delete_all=delete_all if delete_all else None,
        )
        response = await self._request(
            "POST",
            "/cache/delete",
            json=request.model_dump(exclude_none=True),
        )
        return SuccessResponse.model_validate(response.json())

    async def health_check(self) -> HealthResponse:
        """Check API health.

        Returns:
            HealthResponse with status and version
        """
        response = await self._request("GET", "/health")
        return HealthResponse.model_validate(response.json())

View File

@@ -0,0 +1,57 @@
"""Exception classes for Summarizer API errors."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from httpx import Response
class SummarizerAPIError(Exception):
    """Base exception for Summarizer API errors."""

    def __init__(self, message: str, status_code: int | None = None) -> None:
        """Store the human-readable message and optional HTTP status code."""
        super().__init__(message)
        self.message = message
        self.status_code = status_code

    def __str__(self) -> str:
        """Render as '[<status>] <message>' when a status code is present."""
        return f"[{self.status_code}] {self.message}" if self.status_code else self.message
class SummarizerConnectionError(SummarizerAPIError):
    """Raised when connection to the API fails.

    Carries no status code: the failure happens before any HTTP response.
    """
    def __init__(self, message: str = "Failed to connect to Summarizer API") -> None:
        super().__init__(message)
class SummarizerTimeoutError(SummarizerAPIError):
    """Raised when API request times out.

    Carries no status code: the request never completed.
    """
    def __init__(self, message: str = "Request to Summarizer API timed out") -> None:
        super().__init__(message)
class SummarizerBadRequestError(SummarizerAPIError):
    """Raised for 400 Bad Request errors."""

    def __init__(self, message: str, response: Response | None = None) -> None:
        """Keep the failing httpx response (if any) for caller inspection."""
        self.response = response
        super().__init__(message, status_code=400)
class SummarizerServerError(SummarizerAPIError):
    """Raised for 5xx server errors."""

    def __init__(self, message: str, status_code: int = 500) -> None:
        """Record the concrete 5xx status code (defaults to 500)."""
        super().__init__(message, status_code)
class SummarizerRateLimitError(SummarizerAPIError):
    """Raised when rate limit is exceeded."""

    def __init__(self, message: str = "Rate limit exceeded") -> None:
        """Always reports HTTP 429."""
        super().__init__(message, 429)

213
discord_bot/sdk/models.py Normal file
View File

@@ -0,0 +1,213 @@
"""Models for Summarizer API."""
from __future__ import annotations
from enum import StrEnum
from typing import Annotated
from pydantic import BaseModel, Field
class LlmProvider(StrEnum):
"""LLM provider options."""
ANTHROPIC = "anthropic"
OLLAMA = "ollama"
LMSTUDIO = "lmstudio"
class SummarizationStyle(StrEnum):
"""Predefined summarization styles."""
BRIEF = "brief"
DETAILED = "detailed"
FUNNY = "funny"
PROFESSIONAL = "professional"
TECHNICAL = "technical"
ELI5 = "eli5"
BULLETS = "bullets"
ACADEMIC = "academic"
ROAST = "roast"
class TokenUsage(BaseModel):
    """Token usage information from LLM response."""

    input_tokens: int = Field(description="Input tokens (prompt)")
    output_tokens: int = Field(description="Output tokens (completion)")
    total_tokens: int = Field(description="Total tokens used")
    # Optional: not every provider reports pricing.
    estimated_cost_usd: float | None = Field(
        default=None, description="Estimated cost in USD"
    )
class ImageInfo(BaseModel):
    """Image metadata from OCR processing."""

    width: int = Field(description="Image width in pixels")
    height: int = Field(description="Image height in pixels")
    format: str = Field(description="Image format (e.g., PNG, JPEG)")
    size_bytes: int = Field(description="File size in bytes")
class SummarizeRequest(BaseModel):
    """Request payload for text summarization.

    Field constraints mirror the server-side validation limits.
    """

    text: str = Field(
        min_length=1, max_length=100000, description="Text to summarize"
    )
    provider: LlmProvider = Field(description="LLM provider")
    guild_id: str = Field(
        min_length=1, max_length=50, description="Discord guild ID"
    )
    user_id: str = Field(
        min_length=1, max_length=50, description="Discord user ID"
    )
    model: str | None = Field(
        default=None, max_length=100, description="Model name"
    )
    temperature: float | None = Field(
        default=None, ge=0.0, le=2.0, description="Temperature (0.0 to 2.0)"
    )
    max_tokens: int | None = Field(
        default=None, ge=1, le=100000, description="Max tokens for response"
    )
    top_p: float | None = Field(
        default=None, ge=0.0, le=1.0, description="Top P sampling (0.0 to 1.0)"
    )
    style: SummarizationStyle | None = Field(
        default=None, description="Summarization style"
    )
    system_prompt: str | None = Field(
        default=None, max_length=5000, description="Custom system prompt"
    )
    channel_id: str | None = Field(
        default=None, max_length=50, description="Discord channel ID"
    )
class SummarizeResponse(BaseModel):
    """Response from text summarization endpoint."""

    summary: str = Field(description="Generated summary")
    model: str = Field(description="Model used")
    provider: str = Field(description="Provider used")
    from_cache: bool = Field(description="Whether served from cache")
    checksum: str = Field(description="Request checksum")
    timestamp: str = Field(description="Timestamp")
    token_usage: TokenUsage | None = Field(
        default=None, description="Token usage info"
    )
    processing_time_ms: int | None = Field(
        default=None, description="Processing time in milliseconds"
    )
    style_used: str | None = Field(
        default=None, description="Summarization style used"
    )
class OcrSummarizeResponse(BaseModel):
    """Response from OCR summarization endpoint."""

    extracted_text: str = Field(description="Extracted text from OCR")
    summary: str = Field(description="Generated summary")
    model: str = Field(description="Model used")
    provider: str = Field(description="Provider used")
    from_cache: bool = Field(description="Whether served from cache")
    checksum: str = Field(description="Request checksum")
    timestamp: str = Field(description="Timestamp")
    token_usage: TokenUsage | None = Field(
        default=None, description="Token usage info"
    )
    ocr_time_ms: int = Field(description="OCR processing time in milliseconds")
    summarization_time_ms: int | None = Field(
        default=None, description="Summarization time in milliseconds"
    )
    total_time_ms: int = Field(description="Total processing time in milliseconds")
    style_used: str | None = Field(
        default=None, description="Summarization style used"
    )
    image_info: ImageInfo = Field(description="Image metadata")
class LlmInfo(BaseModel):
    """Information about an available LLM model."""

    provider: str = Field(description="Provider name")
    model: str = Field(description="Model name")
    version: str | None = Field(default=None, description="Model version")
    available: bool = Field(description="Whether model is available")
class CacheEntry(BaseModel):
    """A single cache entry."""

    id: int = Field(description="Cache entry ID")
    checksum: str = Field(description="Request checksum")
    text: str = Field(description="Original text")
    summary: str = Field(description="Generated summary")
    provider: str = Field(description="Provider used")
    model: str = Field(description="Model used")
    guild_id: str = Field(description="Guild ID")
    user_id: str = Field(description="User ID")
    channel_id: str | None = Field(default=None, description="Channel ID")
    created_at: str = Field(description="Created timestamp")
    last_accessed: str = Field(description="Last accessed timestamp")
    access_count: int = Field(description="Access count")
class CacheStats(BaseModel):
    """Cache statistics."""

    total_entries: int = Field(description="Total cache entries")
    total_hits: int = Field(description="Total cache hits")
    total_misses: int = Field(description="Total cache misses")
    hit_rate: float = Field(description="Cache hit rate (percentage)")
    total_size_bytes: int = Field(description="Total size in bytes")
class DeleteCacheRequest(BaseModel):
    """Request payload for deleting cache entries.

    All fields are optional; the server interprets whichever selector is set.
    """

    id: int | None = Field(default=None, description="Cache entry ID to delete")
    guild_id: str | None = Field(default=None, description="Delete all for guild")
    user_id: str | None = Field(default=None, description="Delete all for user")
    delete_all: bool | None = Field(default=None, description="Delete all entries")
class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(description="Health status (e.g., 'healthy')")
    version: str = Field(description="API version")
class SuccessResponse(BaseModel):
    """Generic success response."""

    success: bool = Field(description="Operation success status")
    message: str = Field(description="Success message")
class ErrorResponse(BaseModel):
    """Error response from API."""

    success: bool = Field(description="Always False for errors")
    error: str = Field(description="Error message")