use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

/// Provider backed by a local Ollama server. Inference runs locally, so all
/// cost-related methods report zero.
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    client: reqwest::Client,
    base_url: String,
    model: String,
}

#[cfg(feature = "ollama")]
impl OllamaProvider {
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.into(),
            model: model.into(),
        }
    }

    /// Builds a provider from the `OLLAMA_HOST` environment variable,
    /// falling back to the default local endpoint.
    pub fn from_env(model: impl Into<String>) -> Self {
        let base_url = std::env::var("OLLAMA_HOST")
            .unwrap_or_else(|_| "http://localhost:11434".to_string());
        Self::new(base_url, model)
    }

    pub fn llama3() -> Self {
        Self::from_env("llama3")
    }

    pub fn codellama() -> Self {
        Self::from_env("codellama")
    }

    pub fn mistral() -> Self {
        Self::from_env("mistral")
    }
}

#[cfg(feature = "ollama")]
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::llama3()
    }
}

#[cfg(feature = "ollama")]
#[async_trait]
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        // A successful response from /api/tags means the server is reachable.
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .is_ok()
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        // /api/generate takes a single prompt string, so flatten the chat
        // history into labeled turns.
        let prompt = messages
            .iter()
            .map(|m| match m.role {
                crate::providers::Role::System => format!("System: {}", m.content),
                crate::providers::Role::User => format!("User: {}", m.content),
                crate::providers::Role::Assistant => format!("Assistant: {}", m.content),
            })
            .collect::<Vec<_>>()
            .join("\n\n");

        let mut body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "stream": false,
        });
        // Ollama expects sampling parameters nested under an "options" object.
        if let Some(temp) = options.temperature {
            body["options"]["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = options.top_p {
            body["options"]["top_p"] = serde_json::json!(top_p);
        }
        if !options.stop_sequences.is_empty() {
            body["options"]["stop"] = serde_json::json!(options.stop_sequences);
        }

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["response"].as_str().unwrap_or("").to_string();

        // Rough token estimate (~4 characters per token); Ollama also reports
        // exact counts in `prompt_eval_count` / `eval_count` if needed later.
        let input_tokens = prompt.len() as u32 / 4;
        let output_tokens = content.len() as u32 / 4;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "ollama".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: 0.0,
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    // Local inference has no per-token cost.
    fn estimate_cost(&self, _input_tokens: u32, _output_tokens: u32) -> f64 {
        0.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        0.0
    }

    fn cost_per_1m_output(&self) -> f64 {
        0.0
    }
}
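
// --- Illustrative usage sketch (not part of the original module) ---
// Assumptions: `Message { role, content }` can be built as a plain struct
// literal, `GenerationOptions` implements `Default`, and `tokio` is available
// as a dev-dependency; adapt to the crate's actual definitions.
#[cfg(all(test, feature = "ollama"))]
mod ollama_usage {
    use super::*;
    use crate::providers::Role;

    #[tokio::test]
    #[ignore = "requires a running Ollama server (OLLAMA_HOST or localhost:11434)"]
    async fn generates_against_local_server() {
        let provider = OllamaProvider::llama3();
        assert!(provider.is_available().await, "Ollama server not reachable");

        let messages = vec![Message {
            role: Role::User,
            content: "Reply with a single word.".to_string(),
        }];
        let response = provider
            .generate(&messages, &GenerationOptions::default())
            .await
            .expect("generation should succeed");

        assert_eq!(response.provider, "ollama");
        assert!(!response.content.is_empty());
    }
}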