
use async_trait::async_trait;
use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

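/// Provider backed by a locally running Ollama server.
///
/// Requests go to Ollama's HTTP API (`/api/tags` for availability checks,
/// `/api/generate` for completions) at the configured base URL, which
/// defaults to `http://localhost:11434`.
///
/// A minimal usage sketch (assuming the crate's `Message` and
/// `GenerationOptions` values are constructed elsewhere):
///
/// ```ignore
/// let provider = OllamaProvider::from_env("llama3");
/// if provider.is_available().await {
///     let response = provider.generate(&messages, &options).await?;
///     println!("{}", response.content);
/// }
/// ```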
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    /// HTTP client reused across requests.
    client: reqwest::Client,
    /// Base URL of the Ollama server, e.g. `http://localhost:11434`.
    base_url: String,
    /// Name of the model to run, e.g. `llama3`.
    model: String,
}

#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Create a provider for an explicit base URL and model name.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.into(),
            model: model.into(),
        }
    }

    /// Create a provider from the `OLLAMA_HOST` environment variable,
    /// falling back to `http://localhost:11434` when it is unset.
    pub fn from_env(model: impl Into<String>) -> Self {
        let base_url =
            std::env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string());
        Self::new(base_url, model)
    }

    /// Convenience constructor for the `llama3` model.
    pub fn llama3() -> Self {
        Self::from_env("llama3")
    }

    /// Convenience constructor for the `codellama` model.
    pub fn codellama() -> Self {
        Self::from_env("codellama")
    }

    /// Convenience constructor for the `mistral` model.
    pub fn mistral() -> Self {
        Self::from_env("mistral")
    }
}

#[cfg(feature = "ollama")]
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::llama3()
    }
}

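// `LlmProvider` implementation. Ollama runs models locally, so the cost
// accounting below always reports zero; streaming is not wired up yet.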
#[cfg(feature = "ollama")]
#[async_trait]
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        // Any successful response from `/api/tags` means the server is reachable.
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .is_ok()
    }
    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        let prompt = messages
            .iter()
            .map(|m| match m.role {
                crate::providers::Role::System => format!("System: {}", m.content),
                crate::providers::Role::User => format!("User: {}", m.content),
                crate::providers::Role::Assistant => format!("Assistant: {}", m.content),
            })
            .collect::<Vec<_>>()
            .join("\n\n");

        // Ollama expects sampling parameters under a nested "options" object,
        // not at the top level of the request body.
        let mut body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "stream": false,
            "options": {},
        });
        if let Some(temp) = options.temperature {
            body["options"]["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = options.top_p {
            body["options"]["top_p"] = serde_json::json!(top_p);
        }
        if !options.stop_sequences.is_empty() {
            body["options"]["stop"] = serde_json::json!(options.stop_sequences);
        }

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;
        let content = json["response"].as_str().unwrap_or("").to_string();

        // Prefer the token counts reported by Ollama; fall back to a rough
        // ~4-characters-per-token estimate when they are missing.
        let input_tokens = json["prompt_eval_count"]
            .as_u64()
            .map(|n| n as u32)
            .unwrap_or(prompt.len() as u32 / 4);
        let output_tokens = json["eval_count"]
            .as_u64()
            .map(|n| n as u32)
            .unwrap_or(content.len() as u32 / 4);

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "ollama".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: 0.0,
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        // Ollama supports streaming responses, but this provider does not
        // implement them yet.
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    fn estimate_cost(&self, _input_tokens: u32, _output_tokens: u32) -> f64 {
        0.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        0.0
    }

    fn cost_per_1m_output(&self) -> f64 {
        0.0
    }
}
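
// Smoke tests for the constructors and metadata; they exercise only local
// state and do not require a running Ollama server.
#[cfg(all(test, feature = "ollama"))]
mod tests {
    use super::*;

    #[test]
    fn new_sets_base_url_and_model() {
        let provider = OllamaProvider::new("http://localhost:11434", "llama3");
        assert_eq!(provider.name(), "ollama");
        assert_eq!(provider.model(), "llama3");
    }

    #[test]
    fn local_models_cost_nothing() {
        let provider = OllamaProvider::mistral();
        assert_eq!(provider.estimate_cost(1_000, 1_000), 0.0);
        assert_eq!(provider.cost_per_1m_input(), 0.0);
        assert_eq!(provider.cost_per_1m_output(), 0.0);
    }
}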