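//! Demo of the HuggingFace embedding provider in `stratum_embeddings`:
//! predefined models (BGE-small/base, all-MiniLM), a custom Hub model,
//! batch embedding, response caching, and normalized embeddings for
//! cosine-similarity search.
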
use std::time::Duration;

use stratum_embeddings::{
    EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("=== HuggingFace Embedding Provider Demo ===");

    // Example 1: Using a predefined model (BGE-small)
    info!("\n1. Using predefined BGE-small model (384 dimensions)");
    let provider = HuggingFaceProvider::bge_small()?;
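    // No API key is passed here; presumably the provider reads the token
    // from the environment (Example 3 shows supplying one explicitly).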
    let options = EmbeddingOptions::default_with_cache();

    let text = "HuggingFace provides free inference API for embedding models";
    let embedding = provider.embed(text, &options).await?;

    info!(
        "Generated embedding with {} dimensions from BGE-small",
        embedding.len()
    );
    info!("First 5 values: {:?}", &embedding[..5]);

    // Example 2: Using a different model size
    info!("\n2. Using BGE-base model (768 dimensions)");
    let provider = HuggingFaceProvider::bge_base()?;
    let embedding = provider.embed(text, &options).await?;

    info!(
        "Generated embedding with {} dimensions from BGE-base",
        embedding.len()
    );

    // Example 3: Using a custom model
    info!("\n3. Using custom model");
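    // The hosted Inference API needs an access token; create one at
    // https://huggingface.co/settings/tokens and export it before running.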
    let api_key = std::env::var("HUGGINGFACE_API_KEY")
        .or_else(|_| std::env::var("HF_TOKEN"))
        .expect("Set HUGGINGFACE_API_KEY or HF_TOKEN");

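    // Custom takes the Hub repo id plus the embedding dimension the model
    // produces (paraphrase-MiniLM-L6-v2 outputs 384-dimensional vectors).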
    let custom_model = HuggingFaceModel::Custom(
        "sentence-transformers/paraphrase-MiniLM-L6-v2".to_string(),
        384,
    );
    let provider = HuggingFaceProvider::new(api_key, custom_model)?;
    let embedding = provider.embed(text, &options).await?;

    info!("Custom model embedding: {} dimensions", embedding.len());

    // Example 4: Batch embeddings (sequential requests to the HF API)
    info!("\n4. Batch embedding (sequential API calls)");
    let provider = HuggingFaceProvider::all_minilm()?;

    let texts = vec![
        "First document about embeddings".to_string(),
        "Second document about transformers".to_string(),
        "Third document about NLP".to_string(),
    ];

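    // embed_batch expects string slices, so borrow the owned Strings.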
    let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
    let result = provider.embed_batch(&text_refs, &options).await?;

    info!("Embedded {} texts", result.embeddings.len());
    for (i, emb) in result.embeddings.iter().enumerate() {
        info!(" Text {}: {} dimensions", i + 1, emb.len());
    }

    // Example 5: Embedding with a cache
    info!("\n5. Demonstrating cache effectiveness");
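    // A MemoryCache holding up to 1000 entries for five minutes each.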
    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = stratum_embeddings::EmbeddingService::new(HuggingFaceProvider::bge_small()?)
        .with_cache(cache);

    let cached_options = EmbeddingOptions::default_with_cache();

    // First call - cache miss
    let start = std::time::Instant::now();
    let _ = service.embed(text, &cached_options).await?;
    let first_duration = start.elapsed();
    info!("First call (cache miss): {:?}", first_duration);

    // Second call - cache hit
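    // (the same text should now be served straight from memory, with no
    // second network request)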
    let start = std::time::Instant::now();
    let _ = service.embed(text, &cached_options).await?;
    let second_duration = start.elapsed();
    info!("Second call (cache hit): {:?}", second_duration);
    info!(
        "Speedup: {:.2}x",
        first_duration.as_secs_f64() / second_duration.as_secs_f64()
    );
    info!("Cache size: {}", service.cache_size());

    // Example 6: Normalized embeddings for similarity search
    info!("\n6. Normalized embeddings for similarity");
    let provider = HuggingFaceProvider::bge_small()?;
    let normalize_options = EmbeddingOptions {
        normalize: true,
        truncate: true,
        use_cache: true,
    };

    let query = "machine learning embeddings";
    let doc1 = "neural network embeddings for NLP";
    let doc2 = "cooking recipes and ingredients";

    let query_emb = provider.embed(query, &normalize_options).await?;
    let doc1_emb = provider.embed(doc1, &normalize_options).await?;
    let doc2_emb = provider.embed(doc2, &normalize_options).await?;

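    // With normalize: true the embeddings are unit length, so cosine
    // similarity reduces to a plain dot product.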
    let sim1 = stratum_embeddings::cosine_similarity(&query_emb, &doc1_emb);
    let sim2 = stratum_embeddings::cosine_similarity(&query_emb, &doc2_emb);

    info!("Query: '{}'", query);
    info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
    info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
    info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });

    info!("\n=== Demo Complete ===");

    Ok(())
}