(Feat): Initial Commit
This commit is contained in:
402
backend_api/src/llm.rs
Normal file
402
backend_api/src/llm.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::config::Config;
|
||||
use crate::models::{SummarizeRequest, TokenUsage};
|
||||
use crate::v1::enums::LlmProvider;
|
||||
use crate::claude_models::ClaudeModelType;
|
||||
|
||||
/// HTTP client wrapper that dispatches summarization requests to the
/// configured LLM provider (Anthropic, Ollama, or LMStudio).
pub struct LlmClient {
    // Shared reqwest client, reused for every provider call.
    client: Client,
    // Application configuration holding the per-provider sections.
    config: Config,
}
|
||||
|
||||
/// Request payload for the Anthropic Messages API (`POST {base_url}/messages`).
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    messages: Vec<AnthropicMessage>,
    // Optional system prompt; key is omitted from the JSON body when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    // Optional nucleus-sampling parameter; omitted when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}
|
||||
|
||||
/// Single chat message in an Anthropic request (role + text content).
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    // Only "user" is sent by this module; see summarize_anthropic.
    role: String,
    content: String,
}
|
||||
|
||||
/// Subset of the Anthropic Messages API response that this module reads.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicResponse {
    // Content blocks; only the first block's text is used.
    content: Vec<AnthropicContent>,
    model: String,
    usage: AnthropicUsage,
}
|
||||
|
||||
/// Token accounting reported by the Anthropic API.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}
|
||||
|
||||
/// One content block of an Anthropic response; only the text is consumed.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContent {
    text: String,
}
|
||||
|
||||
/// Request payload for Ollama's generate endpoint (`POST {base_url}/api/generate`).
#[derive(Debug, Serialize, Deserialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    // Always false here: the full response is awaited in one shot.
    stream: bool,
    // Optional system prompt; key omitted from JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    options: OllamaOptions,
}
|
||||
|
||||
/// Sampling options nested under `options` in an Ollama request.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaOptions {
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    // Ollama's name for the max-token limit; omitted when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}
|
||||
|
||||
/// Subset of Ollama's non-streaming generate response that this module reads.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaResponse {
    response: String,
    model: String,
    // Token counts may be absent in some responses; treated as 0 downstream.
    prompt_eval_count: Option<u32>,
    eval_count: Option<u32>,
}
|
||||
|
||||
/// One installed model as reported by Ollama's `/api/tags` endpoint.
/// Also reused as the common shape for LMStudio model listings.
#[derive(Debug, Serialize, Deserialize)]
pub struct _OllamaModelInfo {
    pub name: String,
    // RFC 3339 timestamp string as returned by Ollama (empty for LMStudio).
    pub modified_at: String,
    // Model size in bytes (0 for LMStudio, which does not report size).
    pub size: i64,
}
|
||||
|
||||
/// Envelope of Ollama's `/api/tags` response.
#[derive(Debug, Serialize, Deserialize)]
struct _OllamaListResponse {
    models: Vec<_OllamaModelInfo>,
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(config: Config) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn summarize(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
||||
match request.provider {
|
||||
LlmProvider::Anthropic => self.summarize_anthropic(request).await,
|
||||
LlmProvider::Ollama => self.summarize_ollama(request).await,
|
||||
LlmProvider::Lmstudio => self.summarize_lmstudio(request).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn summarize_anthropic(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
||||
let anthropic_config = self.config.anthropic.as_ref()
|
||||
.ok_or_else(|| anyhow!("Anthropic config not found"))?;
|
||||
|
||||
let model = request.model.clone()
|
||||
.unwrap_or_else(|| anthropic_config.model.clone());
|
||||
|
||||
let max_tokens = request.max_tokens
|
||||
.unwrap_or(anthropic_config.max_tokens);
|
||||
|
||||
let temperature = request.temperature
|
||||
.unwrap_or(anthropic_config.temperature);
|
||||
|
||||
let system_prompt = if let Some(ref style) = request.style {
|
||||
style.to_system_prompt()
|
||||
} else {
|
||||
request.system_prompt.clone()
|
||||
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
||||
};
|
||||
|
||||
let anthropic_request = AnthropicRequest {
|
||||
model: model.clone(),
|
||||
max_tokens,
|
||||
temperature,
|
||||
messages: vec![
|
||||
AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: format!("Please summarize the following text:\n\n{}", request.text),
|
||||
}
|
||||
],
|
||||
system: Some(system_prompt),
|
||||
top_p: request.top_p,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(format!("{}/messages", anthropic_config.base_url))
|
||||
.header("x-api-key", &anthropic_config.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&anthropic_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
return Err(anyhow!("Anthropic API error: {}", error_text));
|
||||
}
|
||||
|
||||
let anthropic_response: AnthropicResponse = response.json().await?;
|
||||
|
||||
let summary = anthropic_response.content
|
||||
.first()
|
||||
.map(|c| c.text.clone())
|
||||
.ok_or_else(|| anyhow!("No content in Anthropic response"))?;
|
||||
|
||||
let token_usage = TokenUsage {
|
||||
input_tokens: anthropic_response.usage.input_tokens,
|
||||
output_tokens: anthropic_response.usage.output_tokens,
|
||||
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
|
||||
estimated_cost_usd: Some(calculate_anthropic_cost(
|
||||
&model,
|
||||
anthropic_response.usage.input_tokens,
|
||||
anthropic_response.usage.output_tokens,
|
||||
)),
|
||||
};
|
||||
|
||||
Ok((summary, model, token_usage))
|
||||
}
|
||||
|
||||
async fn summarize_ollama(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
||||
let ollama_config = self.config.ollama.as_ref()
|
||||
.ok_or_else(|| anyhow!("Ollama config not found"))?;
|
||||
|
||||
let model = request.model.clone()
|
||||
.unwrap_or_else(|| ollama_config.model.clone());
|
||||
|
||||
let temperature = request.temperature
|
||||
.unwrap_or(ollama_config.temperature);
|
||||
|
||||
let system_prompt = if let Some(ref style) = request.style {
|
||||
style.to_system_prompt()
|
||||
} else {
|
||||
request.system_prompt.clone()
|
||||
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
||||
};
|
||||
|
||||
let prompt = format!("Please summarize the following text:\n\n{}", request.text);
|
||||
|
||||
let ollama_request = OllamaRequest {
|
||||
model: model.clone(),
|
||||
prompt,
|
||||
stream: false,
|
||||
system: Some(system_prompt),
|
||||
options: OllamaOptions {
|
||||
temperature,
|
||||
top_p: request.top_p,
|
||||
num_predict: request.max_tokens.map(|v| v as i32),
|
||||
},
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(format!("{}/api/generate", ollama_config.base_url))
|
||||
.json(&ollama_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
return Err(anyhow!("Ollama API error: {}", error_text));
|
||||
}
|
||||
|
||||
let ollama_response: OllamaResponse = response.json().await?;
|
||||
|
||||
let token_usage = TokenUsage {
|
||||
input_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
||||
output_tokens: ollama_response.eval_count.unwrap_or(0),
|
||||
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0) + ollama_response.eval_count.unwrap_or(0),
|
||||
estimated_cost_usd: None, // Ollama is local, no cost
|
||||
};
|
||||
|
||||
Ok((ollama_response.response, model, token_usage))
|
||||
}
|
||||
|
||||
pub async fn _list_ollama_models(&self) -> Result<Vec<_OllamaModelInfo>> {
|
||||
let ollama_config = self.config.ollama.as_ref()
|
||||
.ok_or_else(|| anyhow!("Ollama config not found"))?;
|
||||
|
||||
let response = self.client
|
||||
.get(format!("{}/api/tags", ollama_config.base_url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!("Failed to list Ollama models"));
|
||||
}
|
||||
|
||||
let list_response: _OllamaListResponse = response.json().await?;
|
||||
Ok(list_response.models)
|
||||
}
|
||||
|
||||
pub fn _get_anthropic_models(&self) -> Vec<String> {
|
||||
// Return known Claude models from claude_models module
|
||||
ClaudeModelType::all_available()
|
||||
.into_iter()
|
||||
.map(|m| m.id())
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn summarize_lmstudio(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
||||
let lmstudio_config = self.config.lmstudio.as_ref()
|
||||
.ok_or_else(|| anyhow!("LMStudio config not found"))?;
|
||||
|
||||
let model = request.model.clone()
|
||||
.unwrap_or_else(|| lmstudio_config.model.clone());
|
||||
|
||||
let temperature = request.temperature
|
||||
.unwrap_or(lmstudio_config.temperature);
|
||||
|
||||
let max_tokens = request.max_tokens
|
||||
.unwrap_or(lmstudio_config.max_tokens);
|
||||
|
||||
let system_prompt = if let Some(ref style) = request.style {
|
||||
style.to_system_prompt()
|
||||
} else {
|
||||
request.system_prompt.clone()
|
||||
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
||||
};
|
||||
|
||||
// LMStudio uses OpenAI-compatible API
|
||||
#[derive(Serialize)]
|
||||
struct LmStudioRequest {
|
||||
model: String,
|
||||
messages: Vec<LmStudioMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct LmStudioMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioResponse {
|
||||
choices: Vec<LmStudioChoice>,
|
||||
usage: Option<LmStudioUsage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioChoice {
|
||||
message: LmStudioResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioUsage {
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
total_tokens: u32,
|
||||
}
|
||||
|
||||
let lmstudio_request = LmStudioRequest {
|
||||
model: model.clone(),
|
||||
messages: vec![
|
||||
LmStudioMessage {
|
||||
role: "system".to_string(),
|
||||
content: system_prompt,
|
||||
},
|
||||
LmStudioMessage {
|
||||
role: "user".to_string(),
|
||||
content: format!("Please summarize the following text:\n\n{}", request.text),
|
||||
},
|
||||
],
|
||||
temperature,
|
||||
max_tokens,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(format!("{}/chat/completions", lmstudio_config.base_url))
|
||||
.json(&lmstudio_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
return Err(anyhow!("LMStudio API error: {}", error_text));
|
||||
}
|
||||
|
||||
let lmstudio_response: LmStudioResponse = response.json().await?;
|
||||
|
||||
let summary = lmstudio_response.choices
|
||||
.first()
|
||||
.map(|c| c.message.content.clone())
|
||||
.ok_or_else(|| anyhow!("No content in LMStudio response"))?;
|
||||
|
||||
let token_usage = if let Some(usage) = lmstudio_response.usage {
|
||||
TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
total_tokens: usage.total_tokens,
|
||||
estimated_cost_usd: None, // LMStudio is local, no cost
|
||||
}
|
||||
} else {
|
||||
TokenUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
total_tokens: 0,
|
||||
estimated_cost_usd: None,
|
||||
}
|
||||
};
|
||||
|
||||
Ok((summary, model, token_usage))
|
||||
}
|
||||
|
||||
pub async fn _list_lmstudio_models(&self) -> Result<Vec<_OllamaModelInfo>> {
|
||||
let lmstudio_config = self.config.lmstudio.as_ref()
|
||||
.ok_or_else(|| anyhow!("LMStudio config not found"))?;
|
||||
|
||||
let response = self.client
|
||||
.get(format!("{}/models", lmstudio_config.base_url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!("Failed to list LMStudio models"));
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioModelsResponse {
|
||||
data: Vec<LmStudioModel>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LmStudioModel {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let list_response: LmStudioModelsResponse = response.json().await?;
|
||||
|
||||
// Convert to OllamaModelInfo format for compatibility
|
||||
Ok(list_response.data.into_iter().map(|m| _OllamaModelInfo {
|
||||
name: m.id,
|
||||
modified_at: "".to_string(),
|
||||
size: 0,
|
||||
}).collect())
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate Anthropic API costs based on model and token usage using ClaudeModelType
|
||||
fn calculate_anthropic_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
|
||||
let claude_model = ClaudeModelType::from_id(model);
|
||||
claude_model.calculate_cost(input_tokens, output_tokens)
|
||||
}
|
||||
Reference in New Issue
Block a user