use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

/// LLM provider backed by a local Ollama HTTP server.
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    client: reqwest::Client,
    base_url: String,
    model: String,
}

#[cfg(feature = "ollama")]
impl OllamaProvider {
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.into(),
            model: model.into(),
        }
    }

    /// Builds a provider from the `OLLAMA_HOST` environment variable,
    /// falling back to the default local endpoint.
    pub fn from_env(model: impl Into<String>) -> Self {
        let base_url =
            std::env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string());
        Self::new(base_url, model)
    }

    pub fn llama3() -> Self {
        Self::from_env("llama3")
    }

    pub fn codellama() -> Self {
        Self::from_env("codellama")
    }

    pub fn mistral() -> Self {
        Self::from_env("mistral")
    }
}

#[cfg(feature = "ollama")]
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::llama3()
    }
}

#[cfg(feature = "ollama")]
#[async_trait]
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    /// Probes the `/api/tags` endpoint to check whether the server is reachable.
    async fn is_available(&self) -> bool {
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .is_ok()
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        // `/api/generate` takes a single prompt string, so the chat history is
        // flattened into role-prefixed paragraphs.
        let prompt = messages
            .iter()
            .map(|m| match m.role {
                crate::providers::Role::System => format!("System: {}", m.content),
                crate::providers::Role::User => format!("User: {}", m.content),
                crate::providers::Role::Assistant => format!("Assistant: {}", m.content),
            })
            .collect::<Vec<_>>()
            .join("\n\n");

        let mut body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "stream": false,
        });

        // Ollama expects sampling parameters nested under the `options` key;
        // top-level `temperature`/`top_p`/`stop` fields are not honored.
        if let Some(temp) = options.temperature {
            body["options"]["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = options.top_p {
            body["options"]["top_p"] = serde_json::json!(top_p);
        }
        if !options.stop_sequences.is_empty() {
            body["options"]["stop"] = serde_json::json!(options.stop_sequences);
        }

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["response"].as_str().unwrap_or("").to_string();

        // Rough estimate (~4 characters per token); Ollama also reports exact
        // `prompt_eval_count`/`eval_count` fields that could be used instead.
        let input_tokens = prompt.len() as u32 / 4;
        let output_tokens = content.len() as u32 / 4;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "ollama".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: 0.0,
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    // Local inference has no per-token cost.
    fn estimate_cost(&self, _input_tokens: u32, _output_tokens: u32) -> f64 {
        0.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        0.0
    }

    fn cost_per_1m_output(&self) -> f64 {
        0.0
    }
}
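
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): illustrates how the provider
// is expected to be driven end to end. It assumes `tokio` is available as a
// dev-dependency and that `Message` and `GenerationOptions` can be constructed
// as shown; those shapes are assumptions, not guarantees from this crate.
// ---------------------------------------------------------------------------
#[cfg(all(test, feature = "ollama"))]
mod usage_sketch {
    use super::*;
    use crate::providers::Role;

    // Requires a running Ollama server at OLLAMA_HOST (or localhost:11434),
    // so the test is ignored by default.
    #[tokio::test]
    #[ignore]
    async fn generates_with_a_local_model() {
        let provider = OllamaProvider::llama3();
        assert!(provider.is_available().await);

        // Hypothetical message construction; the real field set may differ.
        let messages = [Message {
            role: Role::User,
            content: "Reply with a single word.".to_string(),
        }];

        let response = provider
            .generate(&messages, &GenerationOptions::default())
            .await
            .expect("generation against a local Ollama server");
        assert!(!response.content.is_empty());
        assert_eq!(response.provider, "ollama");
    }
}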