chore: add huggingface provider and examples
This commit is contained in:
parent
0ae853c2fa
commit
9864f88c14
@ -1,8 +1,6 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use stratum_embeddings::{
|
||||
EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
|
||||
};
|
||||
use stratum_embeddings::{EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache};
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main]
|
||||
@ -48,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let provider = HuggingFaceProvider::new(api_key, custom_model)?;
|
||||
let embedding = provider.embed(text, &options).await?;
|
||||
|
||||
info!(
|
||||
"Custom model embedding: {} dimensions",
|
||||
embedding.len()
|
||||
);
|
||||
info!("Custom model embedding: {} dimensions", embedding.len());
|
||||
|
||||
// Example 4: Batch embeddings (sequential requests to HF API)
|
||||
info!("\n4. Batch embedding (sequential API calls)");
|
||||
@ -74,9 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Example 5: Using with cache
|
||||
info!("\n5. Demonstrating cache effectiveness");
|
||||
let cache = MemoryCache::new(1000, Duration::from_secs(300));
|
||||
let service = stratum_embeddings::EmbeddingService::new(
|
||||
HuggingFaceProvider::bge_small()?
|
||||
).with_cache(cache);
|
||||
let service = stratum_embeddings::EmbeddingService::new(HuggingFaceProvider::bge_small()?)
|
||||
.with_cache(cache);
|
||||
|
||||
let cached_options = EmbeddingOptions::default_with_cache();
|
||||
|
||||
@ -91,7 +85,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = service.embed(text, &cached_options).await?;
|
||||
let second_duration = start.elapsed();
|
||||
info!("Second call (cache hit): {:?}", second_duration);
|
||||
info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64());
|
||||
info!(
|
||||
"Speedup: {:.2}x",
|
||||
first_duration.as_secs_f64() / second_duration.as_secs_f64()
|
||||
);
|
||||
info!("Cache size: {}", service.cache_size());
|
||||
|
||||
// Example 6: Normalized embeddings for similarity search
|
||||
@ -117,7 +114,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
info!("Query: '{}'", query);
|
||||
info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
|
||||
info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
|
||||
info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });
|
||||
info!(
|
||||
"Most similar: {}",
|
||||
if sim1 > sim2 { "doc1" } else { "doc2" }
|
||||
);
|
||||
|
||||
info!("\n=== Demo Complete ===");
|
||||
|
||||
|
||||
@ -26,7 +26,8 @@ pub enum HuggingFaceModel {
|
||||
BgeLarge,
|
||||
/// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
|
||||
AllMiniLm,
|
||||
/// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
|
||||
/// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong
|
||||
/// baseline
|
||||
AllMpnet,
|
||||
/// Custom model with model ID and dimensions
|
||||
Custom(String, usize),
|
||||
@ -101,9 +102,13 @@ enum HFResponse {
|
||||
|
||||
#[cfg(feature = "huggingface-provider")]
|
||||
impl HuggingFaceProvider {
|
||||
const BASE_URL: &'static str = "https://api-inference.huggingface.co/pipeline/feature-extraction";
|
||||
const BASE_URL: &'static str =
|
||||
"https://api-inference.huggingface.co/pipeline/feature-extraction";
|
||||
|
||||
pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
|
||||
pub fn new(
|
||||
api_key: impl Into<String>,
|
||||
model: HuggingFaceModel,
|
||||
) -> Result<Self, EmbeddingError> {
|
||||
let api_key = api_key.into();
|
||||
if api_key.is_empty() {
|
||||
return Err(EmbeddingError::ConfigError(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user