//! HuggingFace Inference API embedding provider.
//!
//! Provides [`HuggingFaceProvider`], a remote [`EmbeddingProvider`] backed by
//! the HuggingFace feature-extraction pipeline, and the [`HuggingFaceModel`]
//! catalogue of supported models. Everything network-facing is gated behind
//! the `huggingface-provider` feature.

#[cfg(feature = "huggingface-provider")]
use std::time::Duration;
#[cfg(feature = "huggingface-provider")]
use async_trait::async_trait;
#[cfg(feature = "huggingface-provider")]
use reqwest::Client;
#[cfg(feature = "huggingface-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "huggingface-provider")]
use crate::{
error::EmbeddingError,
traits::{
normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
},
};
/// Embedding models available through the HuggingFace Inference API.
///
/// Known models carry a fixed output dimensionality; arbitrary hub models can
/// be used via [`HuggingFaceModel::Custom`] with an explicit dimension count.
// `Eq` added: all variants hold `Eq` data (unit, `String`, `usize`), and
// deriving only `PartialEq` forfeits use in contexts requiring total equality
// (clippy::derive_partial_eq_without_eq).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HuggingFaceModel {
    /// BAAI/bge-small-en-v1.5 - 384 dimensions, efficient general-purpose
    BgeSmall,
    /// BAAI/bge-base-en-v1.5 - 768 dimensions, balanced performance
    BgeBase,
    /// BAAI/bge-large-en-v1.5 - 1024 dimensions, high quality
    BgeLarge,
    /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
    AllMiniLm,
    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong
    /// baseline
    AllMpnet,
    /// Custom model with model ID and dimensions
    Custom(String, usize),
}
impl Default for HuggingFaceModel {
fn default() -> Self {
Self::BgeSmall
}
}
impl HuggingFaceModel {
pub fn model_id(&self) -> &str {
match self {
Self::BgeSmall => "BAAI/bge-small-en-v1.5",
Self::BgeBase => "BAAI/bge-base-en-v1.5",
Self::BgeLarge => "BAAI/bge-large-en-v1.5",
Self::AllMiniLm => "sentence-transformers/all-MiniLM-L6-v2",
Self::AllMpnet => "sentence-transformers/all-mpnet-base-v2",
Self::Custom(id, _) => id.as_str(),
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::BgeSmall => 384,
Self::BgeBase => 768,
Self::BgeLarge => 1024,
Self::AllMiniLm => 384,
Self::AllMpnet => 768,
Self::Custom(_, dims) => *dims,
}
}
pub fn from_model_id(id: &str, dimensions: Option<usize>) -> Self {
match id {
"BAAI/bge-small-en-v1.5" => Self::BgeSmall,
"BAAI/bge-base-en-v1.5" => Self::BgeBase,
"BAAI/bge-large-en-v1.5" => Self::BgeLarge,
"sentence-transformers/all-MiniLM-L6-v2" => Self::AllMiniLm,
"sentence-transformers/all-mpnet-base-v2" => Self::AllMpnet,
_ => Self::Custom(
id.to_string(),
dimensions.unwrap_or(384), // Default to 384 if unknown
),
}
}
}
/// Remote embedding provider backed by the HuggingFace Inference API
/// feature-extraction pipeline.
#[cfg(feature = "huggingface-provider")]
pub struct HuggingFaceProvider {
    // Reused HTTP client; built once with a 120s request timeout.
    client: Client,
    // Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    // Selected model; determines the request URL and expected dimensions.
    model: HuggingFaceModel,
}
/// Request body for the feature-extraction endpoint: a single input text,
/// serialized as `{"inputs": "..."}`.
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Serialize)]
struct HFRequest {
    inputs: String,
}
/// Response payload from the feature-extraction endpoint.
///
/// `#[serde(untagged)]` tries variants in declaration order, so the JSON
/// shape (flat vector vs. vector of vectors) selects the variant.
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFResponse {
    /// Single text embedding response
    Single(Vec<f32>),
    /// Batch embedding response
    Multiple(Vec<Vec<f32>>),
}
#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
    /// Feature-extraction pipeline endpoint; the model ID is appended as a
    /// path segment per request.
    const BASE_URL: &'static str =
        "https://api-inference.huggingface.co/pipeline/feature-extraction";

    /// Creates a provider with an explicit API key and model.
    ///
    /// # Errors
    ///
    /// Returns [`EmbeddingError::ConfigError`] when `api_key` is empty, or
    /// [`EmbeddingError::Initialization`] when the HTTP client cannot be
    /// built.
    pub fn new(
        api_key: impl Into<String>,
        model: HuggingFaceModel,
    ) -> Result<Self, EmbeddingError> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(EmbeddingError::ConfigError(
                "HuggingFace API key is empty".to_string(),
            ));
        }
        // Generous timeout: remote models may need time before responding.
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;
        Ok(Self {
            client,
            api_key,
            model,
        })
    }

    /// Creates a provider using the `HUGGINGFACE_API_KEY` (preferred) or
    /// `HF_TOKEN` environment variable.
    ///
    /// # Errors
    ///
    /// Returns [`EmbeddingError::ConfigError`] when neither variable is set.
    pub fn from_env(model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
        let api_key = std::env::var("HUGGINGFACE_API_KEY")
            .or_else(|_| std::env::var("HF_TOKEN"))
            .map_err(|_| {
                EmbeddingError::ConfigError(
                    "HUGGINGFACE_API_KEY or HF_TOKEN environment variable not set".to_string(),
                )
            })?;
        Self::new(api_key, model)
    }

    /// Convenience constructor for [`HuggingFaceModel::BgeSmall`] from env.
    pub fn bge_small() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeSmall)
    }

    /// Convenience constructor for [`HuggingFaceModel::BgeBase`] from env.
    pub fn bge_base() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeBase)
    }

    /// Convenience constructor for [`HuggingFaceModel::BgeLarge`] from env.
    /// Added for parity with the other model variants.
    pub fn bge_large() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeLarge)
    }

    /// Convenience constructor for [`HuggingFaceModel::AllMiniLm`] from env.
    pub fn all_minilm() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::AllMiniLm)
    }

    /// Convenience constructor for [`HuggingFaceModel::AllMpnet`] from env.
    /// Added for parity with the other model variants.
    pub fn all_mpnet() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::AllMpnet)
    }
}
#[cfg(feature = "huggingface-provider")]
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        "huggingface"
    }

    fn model(&self) -> &str {
        self.model.model_id()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        // HuggingFace doesn't specify a hard limit, but most models handle ~512 tokens
        512
    }

    fn max_batch_size(&self) -> usize {
        // HuggingFace Inference API doesn't support batch requests
        // Each request is individual
        1
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        // HuggingFace Inference API is free for public models
        // For dedicated endpoints, costs vary
        0.0
    }

    async fn is_available(&self) -> bool {
        // No health probe is performed; the service is assumed reachable and
        // failures (including rate limits) surface at request time instead.
        true
    }

    /// Embeds a single text via the feature-extraction endpoint.
    ///
    /// The returned vector is validated against the model's expected
    /// dimensionality and normalized in place when `options.normalize` is set.
    ///
    /// # Errors
    ///
    /// [`EmbeddingError::InvalidInput`] for empty text,
    /// [`EmbeddingError::ApiError`] for transport/HTTP/parse failures, and
    /// [`EmbeddingError::DimensionMismatch`] when the response vector length
    /// differs from the model's declared dimensions.
    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }
        let url = format!("{}/{}", Self::BASE_URL, self.model.model_id());
        let request = HFRequest {
            inputs: text.to_string(),
        };
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                EmbeddingError::ApiError(format!("HuggingFace API request failed: {}", e))
            })?;
        if !response.status().is_success() {
            let status = response.status();
            // Best-effort error body; an unreadable body still reports status.
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "HuggingFace API error {}: {}",
                status, error_text
            )));
        }
        let hf_response: HFResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse HuggingFace response: {}", e))
        })?;
        let mut embedding = match hf_response {
            HFResponse::Single(emb) => emb,
            // Take ownership of the first row instead of cloning it: the
            // response is consumed here, so a full-vector copy is avoidable.
            HFResponse::Multiple(embs) => embs.into_iter().next().ok_or_else(|| {
                EmbeddingError::ApiError(
                    "Empty embeddings response from HuggingFace".to_string(),
                )
            })?,
        };
        // Validate dimensions
        if embedding.len() != self.dimensions() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions(),
                actual: embedding.len(),
            });
        }
        // Normalize if requested
        if options.normalize {
            normalize_embedding(&mut embedding);
        }
        Ok(embedding)
    }

    /// Embeds each text with a separate request (the Inference API has no
    /// batch endpoint); fails fast on the first error.
    ///
    /// # Errors
    ///
    /// [`EmbeddingError::InvalidInput`] for an empty slice, plus any error
    /// propagated from [`Self::embed`].
    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }
        // HuggingFace Inference API doesn't support true batch requests
        // We need to send individual requests
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            let embedding = self.embed(text, options).await?;
            embeddings.push(embedding);
        }
        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            // The API does not report token usage for this pipeline.
            total_tokens: None,
            cached_count: 0,
        })
    }
}
#[cfg(all(test, feature = "huggingface-provider"))]
mod tests {
    use super::*;

    #[test]
    fn test_model_id_mapping() {
        // Table-driven check of known-model id/dimension pairs.
        let cases = [
            (HuggingFaceModel::BgeSmall, "BAAI/bge-small-en-v1.5", 384),
            (HuggingFaceModel::BgeBase, "BAAI/bge-base-en-v1.5", 768),
        ];
        for (model, expected_id, expected_dims) in cases {
            assert_eq!(model.model_id(), expected_id);
            assert_eq!(model.dimensions(), expected_dims);
        }
    }

    #[test]
    fn test_custom_model() {
        let custom = HuggingFaceModel::Custom(String::from("my-model"), 512);
        assert_eq!(custom.model_id(), "my-model");
        assert_eq!(custom.dimensions(), 512);
    }

    #[test]
    fn test_from_model_id() {
        // Known ids round-trip to their dedicated variant.
        assert_eq!(
            HuggingFaceModel::from_model_id("BAAI/bge-small-en-v1.5", None),
            HuggingFaceModel::BgeSmall
        );
        // Unknown ids become Custom with the supplied dimensions.
        assert!(matches!(
            HuggingFaceModel::from_model_id("unknown-model", Some(256)),
            HuggingFaceModel::Custom(_, 256)
        ));
    }

    #[test]
    fn test_provider_creation() {
        let provider = HuggingFaceProvider::new("test-key", HuggingFaceModel::BgeSmall)
            .expect("non-empty key should build a provider");
        assert_eq!(provider.name(), "huggingface");
        assert_eq!(provider.model(), "BAAI/bge-small-en-v1.5");
        assert_eq!(provider.dimensions(), 384);
    }

    #[test]
    fn test_empty_api_key() {
        let err = HuggingFaceProvider::new("", HuggingFaceModel::BgeSmall)
            .expect_err("empty key must be rejected");
        assert!(matches!(err, EmbeddingError::ConfigError(_)));
    }
}