//! LLM integration using stratum-llm
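//!
//! A minimal usage sketch (illustrative; assumes the `stratum-llm` API shown
//! in the imports below and a valid `ANTHROPIC_API_KEY` in the environment):
//!
//! ```ignore
//! let llm = LlmClient::new("claude-sonnet-4".to_string())?;
//! let answer = llm.generate_answer("How do I provision a node?", &context).await?;
//! println!("{answer}");
//! ```
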
use stratum_llm::{
    AnthropicProvider, ConfiguredProvider, CredentialSource, GenerationOptions, Message,
    ProviderChain, Role, UnifiedClient,
};
use tracing::{info, warn};

use crate::error::{RagError, Result};
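
/// Thin wrapper around a stratum-llm `UnifiedClient`, pinned to a single model.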
pub struct LlmClient {
    client: UnifiedClient,
    pub model: String,
}

impl LlmClient {
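    /// Builds a client backed by the Anthropic provider. The API key is read
    /// from `ANTHROPIC_API_KEY`; if it is missing we only warn here, so
    /// requests will fail at call time rather than at construction.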
    pub fn new(model: String) -> Result<Self> {
        let api_key = std::env::var("ANTHROPIC_API_KEY").unwrap_or_else(|_| {
            warn!("ANTHROPIC_API_KEY not set - LLM calls will fail");
            String::new()
        });

        let provider = AnthropicProvider::new(api_key, model.clone());
        let configured = ConfiguredProvider {
            provider: Box::new(provider),
            credential_source: CredentialSource::EnvVar {
                name: "ANTHROPIC_API_KEY".to_string(),
            },
            priority: 0,
        };
        let chain = ProviderChain::with_providers(vec![configured]);

        let client = UnifiedClient::builder()
            .with_chain(chain)
            .build()
            .map_err(|e| RagError::LlmError(format!("Failed to build LLM client: {}", e)))?;

        info!("Initialized stratum-llm client: {}", model);
        Ok(Self { client, model })
    }
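
    /// Generates an answer to `query`, grounding the model in the retrieved
    /// `context` via the system prompt. Returns the raw response text.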
    pub async fn generate_answer(&self, query: &str, context: &str) -> Result<String> {
        let system_prompt = format!(
            r#"You are a helpful assistant answering questions about a provisioning platform.
You have been provided with relevant documentation context below.
Answer the user's question based on this context.
Be concise and accurate.
# Retrieved Context
{}
"#,
            context
        );
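
        // The system turn carries the grounding context; the user turn
        // carries the question itself.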
        let messages = vec![
            Message {
                role: Role::System,
                content: system_prompt,
            },
            Message {
                role: Role::User,
                content: query.to_string(),
            },
        ];
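
        // Cap the response length; every other generation parameter is left
        // at the stratum-llm default.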
        let options = GenerationOptions {
            max_tokens: Some(1024),
            ..Default::default()
        };

        let response = self
            .client
            .generate(&messages, Some(&options))
            .await
            .map_err(|e| RagError::LlmError(format!("LLM generation failed: {}", e)))?;

        info!("Generated answer: {} characters", response.content.len());
        Ok(response.content)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
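
    // Construction-only tests: `new` just wires up the provider chain, so
    // these should pass without a real ANTHROPIC_API_KEY and without any
    // network access.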
    #[test]
    fn test_llm_client_creation() {
        let client = LlmClient::new("claude-opus-4".to_string());
        assert!(client.is_ok());
    }

    #[test]
    fn test_llm_client_model() {
        let client = LlmClient::new("claude-sonnet-4".to_string());
        assert!(client.is_ok());
        assert_eq!(client.unwrap().model, "claude-sonnet-4");
    }
}