chore: add huggingface provider and examples
Some checks failed
Rust CI / Security Audit (push) Has been cancelled
Rust CI / Check + Test + Lint (nightly) (push) Has been cancelled
Rust CI / Check + Test + Lint (stable) (push) Has been cancelled

This commit is contained in:
Jesús Pérez 2026-01-24 02:25:05 +00:00
parent 0ae853c2fa
commit 9864f88c14
Signed by: jesus
GPG Key ID: 9F243E355E0BC939
2 changed files with 20 additions and 15 deletions

View File

@@ -1,8 +1,6 @@
use std::time::Duration;
use stratum_embeddings::{
EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
};
use stratum_embeddings::{EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache};
use tracing::info;
#[tokio::main]
@@ -48,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let provider = HuggingFaceProvider::new(api_key, custom_model)?;
let embedding = provider.embed(text, &options).await?;
info!(
"Custom model embedding: {} dimensions",
embedding.len()
);
info!("Custom model embedding: {} dimensions", embedding.len());
// Example 4: Batch embeddings (sequential requests to HF API)
info!("\n4. Batch embedding (sequential API calls)");
@@ -74,9 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Example 5: Using with cache
info!("\n5. Demonstrating cache effectiveness");
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = stratum_embeddings::EmbeddingService::new(
HuggingFaceProvider::bge_small()?
).with_cache(cache);
let service = stratum_embeddings::EmbeddingService::new(HuggingFaceProvider::bge_small()?)
.with_cache(cache);
let cached_options = EmbeddingOptions::default_with_cache();
@@ -91,7 +85,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let _ = service.embed(text, &cached_options).await?;
let second_duration = start.elapsed();
info!("Second call (cache hit): {:?}", second_duration);
info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64());
info!(
"Speedup: {:.2}x",
first_duration.as_secs_f64() / second_duration.as_secs_f64()
);
info!("Cache size: {}", service.cache_size());
// Example 6: Normalized embeddings for similarity search
@@ -117,7 +114,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
info!("Query: '{}'", query);
info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });
info!(
"Most similar: {}",
if sim1 > sim2 { "doc1" } else { "doc2" }
);
info!("\n=== Demo Complete ===");

View File

@@ -26,7 +26,8 @@ pub enum HuggingFaceModel {
BgeLarge,
/// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
AllMiniLm,
/// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
/// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong
/// baseline
AllMpnet,
/// Custom model with model ID and dimensions
Custom(String, usize),
@@ -101,9 +102,13 @@ enum HFResponse {
#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
const BASE_URL: &'static str = "https://api-inference.huggingface.co/pipeline/feature-extraction";
const BASE_URL: &'static str =
"https://api-inference.huggingface.co/pipeline/feature-extraction";
pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
pub fn new(
api_key: impl Into<String>,
model: HuggingFaceModel,
) -> Result<Self, EmbeddingError> {
let api_key = api_key.into();
if api_key.is_empty() {
return Err(EmbeddingError::ConfigError(