(Feat): Initial Commit
This commit is contained in:
294
discord_bot/src/views.py
Normal file
294
discord_bot/src/views.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Discord UI components for summarization commands."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||
|
||||
import disnake
|
||||
from loguru import logger
|
||||
|
||||
from sdk import LlmInfo, SummarizerClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.bot import SummarizerBot
|
||||
|
||||
|
||||
class ProviderSelect(disnake.ui.StringSelect):
|
||||
"""Dropdown for selecting LLM provider."""
|
||||
|
||||
def __init__(self, selected: str = "lmstudio") -> None:
|
||||
options = [
|
||||
disnake.SelectOption(
|
||||
label="Anthropic (Claude)",
|
||||
value="anthropic",
|
||||
description="Claude AI models",
|
||||
emoji="🤖",
|
||||
default=(selected == "anthropic"),
|
||||
),
|
||||
disnake.SelectOption(
|
||||
label="Ollama",
|
||||
value="ollama",
|
||||
description="Local Ollama models",
|
||||
emoji="🦙",
|
||||
default=(selected == "ollama"),
|
||||
),
|
||||
disnake.SelectOption(
|
||||
label="LM Studio",
|
||||
value="lmstudio",
|
||||
description="Local LM Studio models",
|
||||
emoji="💻",
|
||||
default=(selected == "lmstudio"),
|
||||
),
|
||||
]
|
||||
super().__init__(
|
||||
placeholder="Select AI Provider",
|
||||
options=options,
|
||||
custom_id="provider_select",
|
||||
)
|
||||
|
||||
async def callback(self, inter: disnake.MessageInteraction) -> None:
|
||||
"""Handle provider selection - triggers model list refresh."""
|
||||
# The view will handle this
|
||||
self.view.selected_provider = self.values[0] # type: ignore[union-attr]
|
||||
await self.view.refresh_models(inter) # type: ignore[union-attr]
|
||||
|
||||
|
||||
class ModelSelect(disnake.ui.StringSelect):
|
||||
"""Dropdown for selecting LLM model."""
|
||||
|
||||
def __init__(self, models: list[LlmInfo] | None = None) -> None:
|
||||
if models:
|
||||
options = [
|
||||
disnake.SelectOption(
|
||||
label=m.model[:100],
|
||||
value=m.model,
|
||||
description=f"{m.provider}"
|
||||
+ (f" v{m.version}" if m.version else ""), # noqa: E501
|
||||
)
|
||||
for m in models[:25] # Discord limit
|
||||
]
|
||||
else:
|
||||
options = [
|
||||
disnake.SelectOption(
|
||||
label="Select provider first",
|
||||
value="none",
|
||||
description="Choose a provider to see available models",
|
||||
)
|
||||
]
|
||||
super().__init__(
|
||||
placeholder="Select Model (optional)",
|
||||
options=options,
|
||||
custom_id="model_select",
|
||||
disabled=not models,
|
||||
)
|
||||
|
||||
async def callback(self, inter: disnake.MessageInteraction) -> None:
|
||||
"""Handle model selection."""
|
||||
self.view.selected_model = self.values[0] if self.values[0] != "none" else None # type: ignore[union-attr]
|
||||
await inter.response.defer()
|
||||
|
||||
|
||||
class ProviderModelView(disnake.ui.View):
|
||||
"""View for selecting provider and model before summarization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: SummarizerBot,
|
||||
on_confirm: Callable[[str, str | None], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
timeout: float = 120.0,
|
||||
author_id: int,
|
||||
) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.bot = bot
|
||||
self.on_confirm = on_confirm
|
||||
self.author_id = author_id
|
||||
self.selected_provider: str = "lmstudio"
|
||||
self.selected_model: str | None = None
|
||||
self._models_cache: dict[str, list[LlmInfo]] = {}
|
||||
|
||||
# Add components
|
||||
self.provider_select = ProviderSelect()
|
||||
self.model_select = ModelSelect()
|
||||
self.add_item(self.provider_select)
|
||||
self.add_item(self.model_select)
|
||||
|
||||
async def interaction_check( # type: ignore[override]
|
||||
self, interaction: disnake.MessageInteraction
|
||||
) -> bool:
|
||||
"""Only allow the command author to interact."""
|
||||
if interaction.author.id != self.author_id:
|
||||
await interaction.response.send_message(
|
||||
"Only the command author can use these controls.", ephemeral=True
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
async def fetch_models(self) -> list[LlmInfo]:
|
||||
"""Fetch available models from API."""
|
||||
try:
|
||||
async with SummarizerClient(base_url=self.bot.config.bot.api_url) as client:
|
||||
return await client.list_models()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(f"Failed to fetch models: {e}")
|
||||
return []
|
||||
|
||||
async def prefetch_models(self) -> None:
|
||||
"""Prefetch and cache all models."""
|
||||
if not self._models_cache:
|
||||
all_models = await self.fetch_models()
|
||||
for model in all_models:
|
||||
provider = model.provider.lower()
|
||||
if provider not in self._models_cache:
|
||||
self._models_cache[provider] = []
|
||||
self._models_cache[provider].append(model)
|
||||
|
||||
def get_provider_models(self, provider: str) -> list[LlmInfo]:
|
||||
"""Get available models for a specific provider."""
|
||||
return [m for m in self._models_cache.get(provider, []) if m.available]
|
||||
|
||||
async def refresh_models(self, inter: disnake.MessageInteraction) -> None:
|
||||
"""Refresh model list based on selected provider."""
|
||||
await inter.response.defer()
|
||||
|
||||
# Fetch models if not cached
|
||||
if not self._models_cache:
|
||||
all_models = await self.fetch_models()
|
||||
for model in all_models:
|
||||
provider = model.provider.lower()
|
||||
if provider not in self._models_cache:
|
||||
self._models_cache[provider] = []
|
||||
self._models_cache[provider].append(model)
|
||||
|
||||
# Filter models for selected provider
|
||||
provider_models = self._models_cache.get(self.selected_provider, [])
|
||||
available_models = [m for m in provider_models if m.available]
|
||||
|
||||
# Update provider select to show current selection
|
||||
self.remove_item(self.provider_select)
|
||||
self.provider_select = ProviderSelect(selected=self.selected_provider)
|
||||
self.add_item(self.provider_select)
|
||||
|
||||
# Update model select
|
||||
self.remove_item(self.model_select)
|
||||
self.model_select = ModelSelect(available_models if available_models else None)
|
||||
self.add_item(self.model_select)
|
||||
self.selected_model = None
|
||||
|
||||
await inter.edit_original_response(view=self)
|
||||
|
||||
@disnake.ui.button(label="Use Defaults", style=disnake.ButtonStyle.secondary, row=2)
|
||||
async def use_defaults(
|
||||
self, _button: disnake.ui.Button, _inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Use default provider and model."""
|
||||
self.stop()
|
||||
await self.on_confirm("lmstudio", None)
|
||||
|
||||
@disnake.ui.button(label="Confirm", style=disnake.ButtonStyle.primary, row=2)
|
||||
async def confirm(
|
||||
self, _button: disnake.ui.Button, _inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Confirm selection and proceed."""
|
||||
self.stop()
|
||||
await self.on_confirm(self.selected_provider, self.selected_model)
|
||||
|
||||
@disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, row=2)
|
||||
async def cancel(
|
||||
self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Cancel the operation."""
|
||||
self.stop()
|
||||
await inter.response.edit_message(
|
||||
content="Operation cancelled.", embed=None, view=None
|
||||
)
|
||||
|
||||
|
||||
class QuickSummarizeView(disnake.ui.View):
|
||||
"""View with quick summarize button and advanced options."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: SummarizerBot,
|
||||
summarize_callback: Callable[[str, str | None], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
timeout: float = 60.0,
|
||||
author_id: int,
|
||||
) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.bot = bot
|
||||
self.summarize_callback = summarize_callback
|
||||
self.author_id = author_id
|
||||
|
||||
async def interaction_check( # type: ignore[override]
|
||||
self, interaction: disnake.MessageInteraction
|
||||
) -> bool:
|
||||
"""Only allow the command author to interact."""
|
||||
if interaction.author.id != self.author_id:
|
||||
await interaction.response.send_message(
|
||||
"Only the command author can use these controls.", ephemeral=True
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@disnake.ui.button(
|
||||
label="Summarize (Default)", style=disnake.ButtonStyle.primary, emoji="⚡"
|
||||
)
|
||||
async def quick_summarize(
|
||||
self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Quick summarize with defaults."""
|
||||
self.stop()
|
||||
await inter.response.defer()
|
||||
await self.summarize_callback("lmstudio", None)
|
||||
|
||||
@disnake.ui.button(
|
||||
label="Choose Provider/Model", style=disnake.ButtonStyle.secondary, emoji="⚙️"
|
||||
)
|
||||
async def advanced_options(
|
||||
self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Show advanced provider/model selection."""
|
||||
self.stop()
|
||||
|
||||
async def on_confirm(provider: str, model: str | None) -> None:
|
||||
await self.summarize_callback(provider, model)
|
||||
|
||||
view = ProviderModelView(self.bot, on_confirm, author_id=self.author_id)
|
||||
|
||||
# Prefetch models
|
||||
await view.prefetch_models()
|
||||
|
||||
# Set up initial model select with lmstudio models
|
||||
lmstudio_models = view.get_provider_models("lmstudio")
|
||||
view.remove_item(view.model_select)
|
||||
view.model_select = ModelSelect(lmstudio_models if lmstudio_models else None)
|
||||
view.add_item(view.model_select)
|
||||
|
||||
embed = disnake.Embed(
|
||||
title="🔧 Advanced Options",
|
||||
description="Select your preferred AI provider and model.",
|
||||
color=disnake.Color.blurple(),
|
||||
)
|
||||
embed.add_field(
|
||||
name="Provider",
|
||||
value="Choose the AI service to use for summarization.",
|
||||
inline=False,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Model",
|
||||
value="Optionally select a specific model (leave empty for default).",
|
||||
inline=False,
|
||||
)
|
||||
|
||||
await inter.response.edit_message(embed=embed, view=view)
|
||||
|
||||
@disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, emoji="✖️")
|
||||
async def cancel(
|
||||
self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Cancel the operation."""
|
||||
self.stop()
|
||||
await inter.response.edit_message(
|
||||
content="Operation cancelled.", embed=None, view=None
|
||||
)
|
||||
Reference in New Issue
Block a user