Jesús Pérez dd68d190ef ci: Update pre-commit hooks configuration
- Exclude problematic markdown files from linting (existing legacy issues)
- Make clippy check less aggressive (warnings only, not -D warnings)
- Move cargo test to manual stage (too slow for pre-commit)
- Exclude SVG files from end-of-file-fixer and trailing-whitespace
- Add markdown linting exclusions for existing documentation

This allows pre-commit hooks to run successfully on new code without
blocking commits due to existing issues in legacy documentation files.
2026-01-11 21:32:56 +00:00

399 lines
11 KiB
Rust

// Embedding provider implementations for vector similarity in Knowledge Graph
// Phase 5.1: Embedding-based KG similarity
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use tracing::debug;
#[derive(Debug, Error)]
pub enum EmbeddingError {
#[error("Provider error: {0}")]
ProviderError(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Request failed: {0}")]
RequestFailed(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("HTTP error: {0}")]
HttpError(#[from] reqwest::Error),
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
}
pub type Result<T> = std::result::Result<T, EmbeddingError>;
/// Trait for embedding providers - converts text to vector embeddings
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Generate embedding for text (returns 1536-dim vector by default)
async fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Batch embed multiple texts (more efficient for providers)
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut results = Vec::new();
for text in texts {
results.push(self.embed(text).await?);
}
Ok(results)
}
/// Provider name for metrics/logging
fn provider_name(&self) -> &str;
/// Model name being used
fn model_name(&self) -> &str;
/// Embedding dimension (usually 1536)
fn embedding_dim(&self) -> usize {
1536
}
}
// ============================================================================
// Ollama Provider (Local, Free)
// ============================================================================
pub struct OllamaEmbedding {
endpoint: String,
model: String,
client: reqwest::Client,
}
impl OllamaEmbedding {
pub fn new(endpoint: String, model: String) -> Self {
Self {
endpoint,
model,
client: reqwest::Client::new(),
}
}
}
#[derive(Debug, Serialize)]
struct OllamaEmbedRequest {
model: String,
prompt: String,
}
#[derive(Debug, Deserialize)]
struct OllamaEmbedResponse {
embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for OllamaEmbedding {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
}
debug!("Embedding text via Ollama ({})", self.model);
let request = OllamaEmbedRequest {
model: self.model.clone(),
prompt: text.to_string(),
};
let response = self
.client
.post(format!("{}/api/embeddings", self.endpoint))
.json(&request)
.send()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(EmbeddingError::RequestFailed(format!(
"Status: {}",
response.status()
)));
}
let data: OllamaEmbedResponse = response
.json()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
Ok(data.embedding)
}
fn provider_name(&self) -> &str {
"ollama"
}
fn model_name(&self) -> &str {
&self.model
}
}
// ============================================================================
// OpenAI Provider (Paid, Fast)
// ============================================================================
pub struct OpenAIEmbedding {
api_key: String,
model: String,
client: reqwest::Client,
}
impl OpenAIEmbedding {
pub fn new(api_key: String, model: String) -> Self {
Self {
api_key,
model,
client: reqwest::Client::new(),
}
}
}
#[derive(Debug, Serialize)]
struct OpenAIEmbedRequest {
model: String,
input: String,
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedResponse {
data: Vec<OpenAIEmbedData>,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedData {
embedding: Vec<f32>,
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbedding {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
}
debug!("Embedding text via OpenAI ({})", self.model);
let request = OpenAIEmbedRequest {
model: self.model.clone(),
input: text.to_string(),
encoding_format: None,
};
let response = self
.client
.post("https://api.openai.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(EmbeddingError::RequestFailed(format!(
"OpenAI API error {}: {}",
status, text
)));
}
let data: OpenAIEmbedResponse = response
.json()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
if data.data.is_empty() {
return Err(EmbeddingError::RequestFailed(
"No embeddings in response".to_string(),
));
}
Ok(data.data[0].embedding.clone())
}
fn provider_name(&self) -> &str {
"openai"
}
fn model_name(&self) -> &str {
&self.model
}
}
// ============================================================================
// HuggingFace Provider (Free, Flexible)
// ============================================================================
pub struct HuggingFaceEmbedding {
api_key: String,
model: String,
client: reqwest::Client,
}
impl HuggingFaceEmbedding {
pub fn new(api_key: String, model: String) -> Self {
Self {
api_key,
model,
client: reqwest::Client::new(),
}
}
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFEmbedResponse {
Single(Vec<f32>),
Multiple(Vec<Vec<f32>>),
}
#[async_trait]
impl EmbeddingProvider for HuggingFaceEmbedding {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
}
debug!("Embedding text via HuggingFace ({})", self.model);
let response = self
.client
.post(format!(
"https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
self.model
))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&serde_json::json!({"inputs": text}))
.send()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(EmbeddingError::RequestFailed(format!(
"HuggingFace API error {}: {}",
status, text
)));
}
let data: HFEmbedResponse = response
.json()
.await
.map_err(|e| EmbeddingError::RequestFailed(e.to_string()))?;
match data {
HFEmbedResponse::Single(embedding) => Ok(embedding),
HFEmbedResponse::Multiple(embeddings) => {
if embeddings.is_empty() {
Err(EmbeddingError::RequestFailed(
"No embeddings in response".to_string(),
))
} else {
Ok(embeddings[0].clone())
}
}
}
}
fn provider_name(&self) -> &str {
"huggingface"
}
fn model_name(&self) -> &str {
&self.model
}
}
// ============================================================================
// Factory function to create providers from environment/config
// ============================================================================
pub async fn create_embedding_provider(provider_name: &str) -> Result<Arc<dyn EmbeddingProvider>> {
match provider_name.to_lowercase().as_str() {
"ollama" => {
let endpoint = std::env::var("OLLAMA_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:11434".to_string());
let model = std::env::var("OLLAMA_EMBEDDING_MODEL")
.unwrap_or_else(|_| "nomic-embed-text".to_string());
debug!("Creating Ollama embedding provider: {}", model);
Ok(Arc::new(OllamaEmbedding::new(endpoint, model)))
}
"openai" => {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| EmbeddingError::ConfigError("OPENAI_API_KEY not set".to_string()))?;
let model = std::env::var("OPENAI_EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
debug!("Creating OpenAI embedding provider: {}", model);
Ok(Arc::new(OpenAIEmbedding::new(api_key, model)))
}
"huggingface" => {
let api_key = std::env::var("HUGGINGFACE_API_KEY").map_err(|_| {
EmbeddingError::ConfigError("HUGGINGFACE_API_KEY not set".to_string())
})?;
let model = std::env::var("HUGGINGFACE_EMBEDDING_MODEL")
.unwrap_or_else(|_| "BAAI/bge-small-en-v1.5".to_string());
debug!("Creating HuggingFace embedding provider: {}", model);
Ok(Arc::new(HuggingFaceEmbedding::new(api_key, model)))
}
_ => Err(EmbeddingError::ConfigError(format!(
"Unknown embedding provider: {}",
provider_name
))),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ollama_provider_creation() {
let ollama = OllamaEmbedding::new(
"http://localhost:11434".to_string(),
"nomic-embed-text".to_string(),
);
assert_eq!(ollama.provider_name(), "ollama");
assert_eq!(ollama.model_name(), "nomic-embed-text");
assert_eq!(ollama.embedding_dim(), 1536);
}
#[test]
fn test_openai_provider_creation() {
let openai =
OpenAIEmbedding::new("test-key".to_string(), "text-embedding-3-small".to_string());
assert_eq!(openai.provider_name(), "openai");
assert_eq!(openai.model_name(), "text-embedding-3-small");
assert_eq!(openai.embedding_dim(), 1536);
}
#[test]
fn test_huggingface_provider_creation() {
let hf =
HuggingFaceEmbedding::new("test-key".to_string(), "BAAI/bge-small-en-v1.5".to_string());
assert_eq!(hf.provider_name(), "huggingface");
assert_eq!(hf.model_name(), "BAAI/bge-small-en-v1.5");
assert_eq!(hf.embedding_dim(), 1536);
}
#[test]
fn test_empty_text_error() {
let embedding_error = EmbeddingError::InvalidInput("Empty text".to_string());
assert!(embedding_error.to_string().contains("Empty text"));
}
}