From 9864f88c14ac030f0fa914fd46c2cf4c1a412fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesu=CC=81s=20Pe=CC=81rez?= Date: Sat, 24 Jan 2026 02:25:05 +0000 Subject: [PATCH] chore: add huggingface provider and examples --- .../examples/huggingface_usage.rs | 24 +++++++++---------- .../src/providers/huggingface.rs | 11 ++++++--- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/crates/stratum-embeddings/examples/huggingface_usage.rs b/crates/stratum-embeddings/examples/huggingface_usage.rs index 6763c68..033ba02 100644 --- a/crates/stratum-embeddings/examples/huggingface_usage.rs +++ b/crates/stratum-embeddings/examples/huggingface_usage.rs @@ -1,8 +1,6 @@ use std::time::Duration; -use stratum_embeddings::{ - EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache, -}; +use stratum_embeddings::{EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache}; use tracing::info; #[tokio::main] @@ -48,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let provider = HuggingFaceProvider::new(api_key, custom_model)?; let embedding = provider.embed(text, &options).await?; - info!( - "Custom model embedding: {} dimensions", - embedding.len() - ); + info!("Custom model embedding: {} dimensions", embedding.len()); // Example 4: Batch embeddings (sequential requests to HF API) info!("\n4. Batch embedding (sequential API calls)"); @@ -74,9 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { // Example 5: Using with cache info!("\n5. Demonstrating cache effectiveness"); let cache = MemoryCache::new(1000, Duration::from_secs(300)); - let service = stratum_embeddings::EmbeddingService::new( - HuggingFaceProvider::bge_small()? - ).with_cache(cache); + let service = stratum_embeddings::EmbeddingService::new(HuggingFaceProvider::bge_small()?) 
+ .with_cache(cache); let cached_options = EmbeddingOptions::default_with_cache(); @@ -91,7 +85,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let _ = service.embed(text, &cached_options).await?; let second_duration = start.elapsed(); info!("Second call (cache hit): {:?}", second_duration); - info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64()); + info!( + "Speedup: {:.2}x", + first_duration.as_secs_f64() / second_duration.as_secs_f64() + ); info!("Cache size: {}", service.cache_size()); // Example 6: Normalized embeddings for similarity search @@ -117,7 +114,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { info!("Query: '{}'", query); info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1); info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2); - info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" }); + info!( + "Most similar: {}", + if sim1 > sim2 { "doc1" } else { "doc2" } + ); info!("\n=== Demo Complete ==="); diff --git a/crates/stratum-embeddings/src/providers/huggingface.rs b/crates/stratum-embeddings/src/providers/huggingface.rs index fff8235..ed3bf7b 100644 --- a/crates/stratum-embeddings/src/providers/huggingface.rs +++ b/crates/stratum-embeddings/src/providers/huggingface.rs @@ -26,7 +26,8 @@ pub enum HuggingFaceModel { BgeLarge, /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast AllMiniLm, - /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline + /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong + /// baseline AllMpnet, /// Custom model with model ID and dimensions Custom(String, usize), @@ -101,9 +102,13 @@ enum HFResponse { #[cfg(feature = "huggingface-provider")] impl HuggingFaceProvider { - const BASE_URL: &'static str = "https://api-inference.huggingface.co/pipeline/feature-extraction"; + const BASE_URL: &'static str = + "https://api-inference.huggingface.co/pipeline/feature-extraction"; - pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> 
Result<Self, EmbeddingError> { + pub fn new( + api_key: impl Into<String>, + model: HuggingFaceModel, + ) -> Result<Self, EmbeddingError> { let api_key = api_key.into(); if api_key.is_empty() { return Err(EmbeddingError::ConfigError(