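//! HuggingFace Inference API embedding provider.
//!
//! `HuggingFaceModel` catalogs common sentence-embedding models (plus a `Custom`
//! escape hatch), and, behind the `huggingface-provider` feature,
//! `HuggingFaceProvider` implements `EmbeddingProvider` against the hosted
//! feature-extraction pipeline.
//!
//! A minimal usage sketch (not compiled here; it assumes a tokio runtime and that
//! `EmbeddingOptions` implements `Default`, which this module does not show):
//!
//! ```ignore
//! let provider = HuggingFaceProvider::from_env(HuggingFaceModel::BgeSmall)?;
//! let embedding = provider
//!     .embed("hello world", &EmbeddingOptions::default())
//!     .await?;
//! assert_eq!(embedding.len(), provider.dimensions());
//! ```
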
#[cfg(feature = "huggingface-provider")]
use std::time::Duration;

#[cfg(feature = "huggingface-provider")]
use async_trait::async_trait;

#[cfg(feature = "huggingface-provider")]
use reqwest::Client;

#[cfg(feature = "huggingface-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "huggingface-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

#[derive(Debug, Clone, PartialEq)]
pub enum HuggingFaceModel {
    /// BAAI/bge-small-en-v1.5 - 384 dimensions, efficient general-purpose
    BgeSmall,
    /// BAAI/bge-base-en-v1.5 - 768 dimensions, balanced performance
    BgeBase,
    /// BAAI/bge-large-en-v1.5 - 1024 dimensions, high quality
    BgeLarge,
    /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
    AllMiniLm,
    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
    AllMpnet,
    /// Custom model with model ID and dimensions
    Custom(String, usize),
}

impl Default for HuggingFaceModel {
    fn default() -> Self {
        Self::BgeSmall
    }
}

impl HuggingFaceModel {
    pub fn model_id(&self) -> &str {
        match self {
            Self::BgeSmall => "BAAI/bge-small-en-v1.5",
            Self::BgeBase => "BAAI/bge-base-en-v1.5",
            Self::BgeLarge => "BAAI/bge-large-en-v1.5",
            Self::AllMiniLm => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMpnet => "sentence-transformers/all-mpnet-base-v2",
            Self::Custom(id, _) => id.as_str(),
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmall => 384,
            Self::BgeBase => 768,
            Self::BgeLarge => 1024,
            Self::AllMiniLm => 384,
            Self::AllMpnet => 768,
            Self::Custom(_, dims) => *dims,
        }
    }

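    /// Resolves a model ID string to a known variant. Unknown IDs become
    /// `Custom` with the supplied dimensions, defaulting to 384.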
    pub fn from_model_id(id: &str, dimensions: Option<usize>) -> Self {
        match id {
            "BAAI/bge-small-en-v1.5" => Self::BgeSmall,
            "BAAI/bge-base-en-v1.5" => Self::BgeBase,
            "BAAI/bge-large-en-v1.5" => Self::BgeLarge,
            "sentence-transformers/all-MiniLM-L6-v2" => Self::AllMiniLm,
            "sentence-transformers/all-mpnet-base-v2" => Self::AllMpnet,
            _ => Self::Custom(
                id.to_string(),
                dimensions.unwrap_or(384), // Default to 384 if unknown
            ),
        }
    }
}

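/// Embedding provider backed by the hosted HuggingFace Inference API
/// (feature-extraction pipeline). Requires an API token; see `new` and
/// `from_env`.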
#[cfg(feature = "huggingface-provider")]
pub struct HuggingFaceProvider {
    client: Client,
    api_key: String,
    model: HuggingFaceModel,
}

#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Serialize)]
struct HFRequest {
    inputs: String,
}

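/// Response payload from the feature-extraction pipeline. Deserialized as an
/// untagged enum because the API may return either a single vector or a list
/// of vectors; `embed` uses the first vector in the latter case.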
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFResponse {
    /// Single text embedding response
    Single(Vec<f32>),
    /// Batch embedding response
    Multiple(Vec<Vec<f32>>),
}

#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
    const BASE_URL: &'static str =
        "https://api-inference.huggingface.co/pipeline/feature-extraction";

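    /// Creates a provider from an explicit API key. Fails if the key is empty
    /// and builds a `reqwest` client with a generous 120-second timeout, since
    /// the hosted API can take a while to load a model on first use.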
    pub fn new(
        api_key: impl Into<String>,
        model: HuggingFaceModel,
    ) -> Result<Self, EmbeddingError> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(EmbeddingError::ConfigError(
                "HuggingFace API key is empty".to_string(),
            ));
        }

        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            client,
            api_key,
            model,
        })
    }

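    /// Creates a provider using the `HUGGINGFACE_API_KEY` environment
    /// variable, falling back to `HF_TOKEN`.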
    pub fn from_env(model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
        let api_key = std::env::var("HUGGINGFACE_API_KEY")
            .or_else(|_| std::env::var("HF_TOKEN"))
            .map_err(|_| {
                EmbeddingError::ConfigError(
                    "HUGGINGFACE_API_KEY or HF_TOKEN environment variable not set".to_string(),
                )
            })?;
        Self::new(api_key, model)
    }

    pub fn bge_small() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeSmall)
    }

    pub fn bge_base() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeBase)
    }

    pub fn all_minilm() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::AllMiniLm)
    }
}

#[cfg(feature = "huggingface-provider")]
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        "huggingface"
    }

    fn model(&self) -> &str {
        self.model.model_id()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        // HuggingFace doesn't specify a hard limit, but most models handle ~512 tokens
        512
    }

    fn max_batch_size(&self) -> usize {
        // HuggingFace Inference API doesn't support batch requests
        // Each request is individual
        1
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        // HuggingFace Inference API is free for public models
        // For dedicated endpoints, costs vary
        0.0
    }

    async fn is_available(&self) -> bool {
        // HuggingFace Inference API is always available (with rate limits)
        true
    }

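    /// Embeds a single text: POST to `{BASE_URL}/{model_id}` with a bearer
    /// token, parse the (possibly nested) response, validate the vector
    /// length against the model's dimensions, and normalize in place when
    /// `options.normalize` is set.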
    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        let url = format!("{}/{}", Self::BASE_URL, self.model.model_id());
        let request = HFRequest {
            inputs: text.to_string(),
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                EmbeddingError::ApiError(format!("HuggingFace API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "HuggingFace API error {}: {}",
                status, error_text
            )));
        }

        let hf_response: HFResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse HuggingFace response: {}", e))
        })?;

        let mut embedding = match hf_response {
            HFResponse::Single(emb) => emb,
            HFResponse::Multiple(embs) => {
                if embs.is_empty() {
                    return Err(EmbeddingError::ApiError(
                        "Empty embeddings response from HuggingFace".to_string(),
                    ));
                }
                embs[0].clone()
            }
        };

        // Validate dimensions
        if embedding.len() != self.dimensions() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions(),
                actual: embedding.len(),
            });
        }

        // Normalize if requested
        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

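    /// Embeds each text with its own request, since the hosted Inference API
    /// has no batch endpoint (see `max_batch_size`); `total_tokens` is not
    /// reported and is left as `None`.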
    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        // HuggingFace Inference API doesn't support true batch requests
        // We need to send individual requests
        let mut embeddings = Vec::with_capacity(texts.len());

        for text in texts {
            let embedding = self.embed(text, options).await?;
            embeddings.push(embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            total_tokens: None,
            cached_count: 0,
        })
    }
}

#[cfg(all(test, feature = "huggingface-provider"))]
mod tests {
    use super::*;

    #[test]
    fn test_model_id_mapping() {
        assert_eq!(
            HuggingFaceModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeSmall.dimensions(), 384);

        assert_eq!(
            HuggingFaceModel::BgeBase.model_id(),
            "BAAI/bge-base-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeBase.dimensions(), 768);
    }

    #[test]
    fn test_custom_model() {
        let custom = HuggingFaceModel::Custom("my-model".to_string(), 512);
        assert_eq!(custom.model_id(), "my-model");
        assert_eq!(custom.dimensions(), 512);
    }

    #[test]
    fn test_from_model_id() {
        let model = HuggingFaceModel::from_model_id("BAAI/bge-small-en-v1.5", None);
        assert_eq!(model, HuggingFaceModel::BgeSmall);

        let custom = HuggingFaceModel::from_model_id("unknown-model", Some(256));
        assert!(matches!(custom, HuggingFaceModel::Custom(_, 256)));
    }

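    // Unknown IDs with no explicit dimensions fall back to 384, matching the
    // `unwrap_or(384)` default in `from_model_id`.
    #[test]
    fn test_from_model_id_default_dimensions() {
        let custom = HuggingFaceModel::from_model_id("unknown-model", None);
        assert!(matches!(custom, HuggingFaceModel::Custom(_, 384)));
    }
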
    #[test]
    fn test_provider_creation() {
        let provider = HuggingFaceProvider::new("test-key", HuggingFaceModel::BgeSmall);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.name(), "huggingface");
        assert_eq!(provider.model(), "BAAI/bge-small-en-v1.5");
        assert_eq!(provider.dimensions(), 384);
    }

    #[test]
    fn test_empty_api_key() {
        let provider = HuggingFaceProvider::new("", HuggingFaceModel::BgeSmall);
        assert!(provider.is_err());
        assert!(matches!(
            provider.unwrap_err(),
            EmbeddingError::ConfigError(_)
        ));
    }
}