#[cfg(feature = "huggingface-provider")]
use std::time::Duration;

#[cfg(feature = "huggingface-provider")]
use async_trait::async_trait;
#[cfg(feature = "huggingface-provider")]
use reqwest::Client;
#[cfg(feature = "huggingface-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "huggingface-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

/// Embedding models usable through the HuggingFace Inference API.
///
/// Each built-in variant maps to a fixed model ID and output dimensionality;
/// `Custom` carries both explicitly for models not listed here.
#[derive(Debug, Clone, PartialEq)]
pub enum HuggingFaceModel {
    /// BAAI/bge-small-en-v1.5 - 384 dimensions, efficient general-purpose
    BgeSmall,
    /// BAAI/bge-base-en-v1.5 - 768 dimensions, balanced performance
    BgeBase,
    /// BAAI/bge-large-en-v1.5 - 1024 dimensions, high quality
    BgeLarge,
    /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
    AllMiniLm,
    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
    AllMpnet,
    /// Custom model with model ID and dimensions
    Custom(String, usize),
}

impl Default for HuggingFaceModel {
    /// `BgeSmall` is the default: smallest built-in model, good general-purpose choice.
    fn default() -> Self {
        Self::BgeSmall
    }
}

impl HuggingFaceModel {
    /// Returns the HuggingFace Hub model identifier for this variant.
    pub fn model_id(&self) -> &str {
        match self {
            Self::BgeSmall => "BAAI/bge-small-en-v1.5",
            Self::BgeBase => "BAAI/bge-base-en-v1.5",
            Self::BgeLarge => "BAAI/bge-large-en-v1.5",
            Self::AllMiniLm => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMpnet => "sentence-transformers/all-mpnet-base-v2",
            Self::Custom(id, _) => id.as_str(),
        }
    }

    /// Returns the dimensionality of the embedding vectors this model produces.
    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmall => 384,
            Self::BgeBase => 768,
            Self::BgeLarge => 1024,
            Self::AllMiniLm => 384,
            Self::AllMpnet => 768,
            Self::Custom(_, dims) => *dims,
        }
    }

    /// Maps a model ID string back to a known variant.
    ///
    /// Unrecognized IDs become `Custom`; `dimensions` is only consulted in that
    /// case (known models have fixed dimensionality) and defaults to 384 when
    /// not provided.
    pub fn from_model_id(id: &str, dimensions: Option<usize>) -> Self {
        match id {
            "BAAI/bge-small-en-v1.5" => Self::BgeSmall,
            "BAAI/bge-base-en-v1.5" => Self::BgeBase,
            "BAAI/bge-large-en-v1.5" => Self::BgeLarge,
            "sentence-transformers/all-MiniLM-L6-v2" => Self::AllMiniLm,
            "sentence-transformers/all-mpnet-base-v2" => Self::AllMpnet,
            _ => Self::Custom(
                id.to_string(),
                dimensions.unwrap_or(384), // Default to 384 if unknown
            ),
        }
    }
}

/// Embedding provider backed by the hosted HuggingFace Inference API.
#[cfg(feature = "huggingface-provider")]
pub struct HuggingFaceProvider {
    client: Client,
    api_key: String,
    model: HuggingFaceModel,
}

/// Request body for the feature-extraction pipeline (single input per request).
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Serialize)]
struct HFRequest {
    inputs: String,
}

/// Response shapes returned by the feature-extraction endpoint.
///
/// `untagged` lets serde try `Single` (flat vector) first and fall back to
/// `Multiple` (nested vectors) — the API returns either depending on the model.
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFResponse {
    /// Single text embedding response
    Single(Vec<f32>),
    /// Batch embedding response
    Multiple(Vec<Vec<f32>>),
}

#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
    const BASE_URL: &'static str =
        "https://api-inference.huggingface.co/pipeline/feature-extraction";

    /// Creates a provider with an explicit API key and model.
    ///
    /// # Errors
    /// Returns `EmbeddingError::ConfigError` for an empty key, or
    /// `EmbeddingError::Initialization` if the HTTP client cannot be built.
    pub fn new(
        api_key: impl Into<String>,
        model: HuggingFaceModel,
    ) -> Result<Self, EmbeddingError> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(EmbeddingError::ConfigError(
                "HuggingFace API key is empty".to_string(),
            ));
        }

        // 120s timeout: cold models on the Inference API can take a while to load.
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            client,
            api_key,
            model,
        })
    }

    /// Creates a provider reading the key from `HUGGINGFACE_API_KEY` or `HF_TOKEN`.
    ///
    /// # Errors
    /// Returns `EmbeddingError::ConfigError` when neither variable is set.
    pub fn from_env(model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
        let api_key = std::env::var("HUGGINGFACE_API_KEY")
            .or_else(|_| std::env::var("HF_TOKEN"))
            .map_err(|_| {
                EmbeddingError::ConfigError(
                    "HUGGINGFACE_API_KEY or HF_TOKEN environment variable not set".to_string(),
                )
            })?;
        Self::new(api_key, model)
    }

    /// Convenience constructor: BGE-small from environment credentials.
    pub fn bge_small() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeSmall)
    }

    /// Convenience constructor: BGE-base from environment credentials.
    pub fn bge_base() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeBase)
    }

    /// Convenience constructor: all-MiniLM from environment credentials.
    pub fn all_minilm() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::AllMiniLm)
    }
}

#[cfg(feature = "huggingface-provider")]
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        "huggingface"
    }

    fn model(&self) -> &str {
        self.model.model_id()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        // HuggingFace doesn't specify a hard limit, but most models handle ~512 tokens
        512
    }

    fn max_batch_size(&self) -> usize {
        // HuggingFace Inference API doesn't support batch requests
        // Each request is individual
        1
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        // HuggingFace Inference API is free for public models
        // For dedicated endpoints, costs vary
        0.0
    }

    async fn is_available(&self) -> bool {
        // HuggingFace Inference API is always available (with rate limits)
        true
    }

    /// Embeds a single text via the feature-extraction endpoint.
    ///
    /// Validates the response dimensionality against the configured model and
    /// optionally L2-normalizes the result when `options.normalize` is set.
    ///
    /// # Errors
    /// `InvalidInput` for empty text, `ApiError` for transport/HTTP/parse
    /// failures, `DimensionMismatch` when the API returns an unexpected size.
    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        let url = format!("{}/{}", Self::BASE_URL, self.model.model_id());
        let request = HFRequest {
            inputs: text.to_string(),
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                EmbeddingError::ApiError(format!("HuggingFace API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "HuggingFace API error {}: {}",
                status, error_text
            )));
        }

        let hf_response: HFResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse HuggingFace response: {}", e))
        })?;

        let mut embedding = match hf_response {
            HFResponse::Single(emb) => emb,
            // Take ownership of the first row instead of cloning it.
            HFResponse::Multiple(embs) => embs.into_iter().next().ok_or_else(|| {
                EmbeddingError::ApiError(
                    "Empty embeddings response from HuggingFace".to_string(),
                )
            })?,
        };

        // Validate dimensions
        if embedding.len() != self.dimensions() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions(),
                actual: embedding.len(),
            });
        }

        // Normalize if requested
        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

    /// Embeds a batch of texts by issuing one sequential request per text.
    ///
    /// The Inference API has no true batch endpoint, so any per-text failure
    /// aborts the whole batch via `?`.
    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        // HuggingFace Inference API doesn't support true batch requests
        // We need to send individual requests
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            let embedding = self.embed(text, options).await?;
            embeddings.push(embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            // NOTE(review): the API response carries no token counts, so none are reported.
            total_tokens: None,
            cached_count: 0,
        })
    }
}

#[cfg(all(test, feature = "huggingface-provider"))]
mod tests {
    use super::*;

    #[test]
    fn test_model_id_mapping() {
        assert_eq!(
            HuggingFaceModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeSmall.dimensions(), 384);
        assert_eq!(
            HuggingFaceModel::BgeBase.model_id(),
            "BAAI/bge-base-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeBase.dimensions(), 768);
    }

    #[test]
    fn test_custom_model() {
        let custom = HuggingFaceModel::Custom("my-model".to_string(), 512);
        assert_eq!(custom.model_id(), "my-model");
        assert_eq!(custom.dimensions(), 512);
    }

    #[test]
    fn test_from_model_id() {
        let model = HuggingFaceModel::from_model_id("BAAI/bge-small-en-v1.5", None);
        assert_eq!(model, HuggingFaceModel::BgeSmall);

        let custom = HuggingFaceModel::from_model_id("unknown-model", Some(256));
        assert!(matches!(custom, HuggingFaceModel::Custom(_, 256)));
    }

    #[test]
    fn test_provider_creation() {
        let provider = HuggingFaceProvider::new("test-key", HuggingFaceModel::BgeSmall);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.name(), "huggingface");
        assert_eq!(provider.model(), "BAAI/bge-small-en-v1.5");
        assert_eq!(provider.dimensions(), 384);
    }

    #[test]
    fn test_empty_api_key() {
        let provider = HuggingFaceProvider::new("", HuggingFaceModel::BgeSmall);
        assert!(provider.is_err());
        assert!(matches!(
            provider.unwrap_err(),
            EmbeddingError::ConfigError(_)
        ));
    }
}