chore: add huggingface provider and examples
parent 0ae853c2fa
commit 9864f88c14
@@ -1,8 +1,6 @@
 use std::time::Duration;
 
-use stratum_embeddings::{
-    EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
-};
+use stratum_embeddings::{EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache};
 use tracing::info;
 
 #[tokio::main]
@@ -48,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let provider = HuggingFaceProvider::new(api_key, custom_model)?;
     let embedding = provider.embed(text, &options).await?;
 
-    info!(
-        "Custom model embedding: {} dimensions",
-        embedding.len()
-    );
+    info!("Custom model embedding: {} dimensions", embedding.len());
 
     // Example 4: Batch embeddings (sequential requests to HF API)
     info!("\n4. Batch embedding (sequential API calls)");
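For context, a minimal standalone sketch of the custom-model path this hunk touches, assuming only the stratum_embeddings items visible in the diff (HuggingFaceProvider::new, HuggingFaceModel::Custom, provider.embed); the HF_API_KEY variable, the model ID, and EmbeddingOptions::default() are illustrative assumptions, not taken from the commit:

    use stratum_embeddings::{EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Hypothetical custom model: an HF feature-extraction model ID plus its dimension.
        let custom_model = HuggingFaceModel::Custom("intfloat/e5-small-v2".to_string(), 384);
        let api_key = std::env::var("HF_API_KEY")?; // assumed env var name
        let provider = HuggingFaceProvider::new(api_key, custom_model)?;

        let options = EmbeddingOptions::default(); // assumes a Default impl
        let embedding = provider.embed("hello embeddings", &options).await?;
        println!("{} dimensions", embedding.len());
        Ok(())
    }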
@@ -74,9 +69,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Example 5: Using with cache
     info!("\n5. Demonstrating cache effectiveness");
     let cache = MemoryCache::new(1000, Duration::from_secs(300));
-    let service = stratum_embeddings::EmbeddingService::new(
-        HuggingFaceProvider::bge_small()?
-    ).with_cache(cache);
+    let service = stratum_embeddings::EmbeddingService::new(HuggingFaceProvider::bge_small()?)
+        .with_cache(cache);
 
     let cached_options = EmbeddingOptions::default_with_cache();
 
@@ -91,7 +85,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let _ = service.embed(text, &cached_options).await?;
     let second_duration = start.elapsed();
     info!("Second call (cache hit): {:?}", second_duration);
-    info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64());
+    info!(
+        "Speedup: {:.2}x",
+        first_duration.as_secs_f64() / second_duration.as_secs_f64()
+    );
     info!("Cache size: {}", service.cache_size());
 
     // Example 6: Normalized embeddings for similarity search
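For reference, the cache-effectiveness measurement that this and the previous hunk reformat, pulled together into one sketch; it assumes the EmbeddingService, MemoryCache, bge_small, default_with_cache, and cache_size APIs shown in the diff, and uses std::time::Instant for the timing harness:

    use std::time::{Duration, Instant};
    use stratum_embeddings::{EmbeddingOptions, EmbeddingService, HuggingFaceProvider, MemoryCache};

    async fn measure_cache_speedup(text: &str) -> Result<(), Box<dyn std::error::Error>> {
        // 1000-entry cache with a 5-minute TTL, as in the example above.
        let cache = MemoryCache::new(1000, Duration::from_secs(300));
        let service = EmbeddingService::new(HuggingFaceProvider::bge_small()?).with_cache(cache);
        let cached_options = EmbeddingOptions::default_with_cache();

        let start = Instant::now();
        let _ = service.embed(text, &cached_options).await?; // network call, fills the cache
        let first_duration = start.elapsed();

        let start = Instant::now();
        let _ = service.embed(text, &cached_options).await?; // served from MemoryCache
        let second_duration = start.elapsed();

        println!(
            "Speedup: {:.2}x",
            first_duration.as_secs_f64() / second_duration.as_secs_f64()
        );
        println!("Cache size: {}", service.cache_size());
        Ok(())
    }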
@@ -117,7 +114,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     info!("Query: '{}'", query);
     info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
     info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
-    info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });
+    info!(
+        "Most similar: {}",
+        if sim1 > sim2 { "doc1" } else { "doc2" }
+    );
 
     info!("\n=== Demo Complete ===");
 
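The sim1/sim2 values compared in the hunk above come from the demo's normalized-embedding similarity search (Example 6). A small illustrative helper, not part of the crate, assuming f32 embeddings that are already unit-normalized so cosine similarity reduces to a dot product:

    /// Cosine similarity of two already-normalized embeddings: a plain dot product.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "embeddings must share a dimension");
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Pick the label printed by the "Most similar" log line above.
    fn most_similar(sim1: f32, sim2: f32) -> &'static str {
        if sim1 > sim2 { "doc1" } else { "doc2" }
    }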
@@ -26,7 +26,8 @@ pub enum HuggingFaceModel {
     BgeLarge,
     /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
     AllMiniLm,
-    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
+    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong
+    /// baseline
     AllMpnet,
     /// Custom model with model ID and dimensions
     Custom(String, usize),
@@ -101,9 +102,13 @@ enum HFResponse {
 
 #[cfg(feature = "huggingface-provider")]
 impl HuggingFaceProvider {
-    const BASE_URL: &'static str = "https://api-inference.huggingface.co/pipeline/feature-extraction";
+    const BASE_URL: &'static str =
+        "https://api-inference.huggingface.co/pipeline/feature-extraction";
 
-    pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
+    pub fn new(
+        api_key: impl Into<String>,
+        model: HuggingFaceModel,
+    ) -> Result<Self, EmbeddingError> {
         let api_key = api_key.into();
         if api_key.is_empty() {
             return Err(EmbeddingError::ConfigError(
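Finally, a sketch of how the reformatted constructor is used and how it rejects a missing key, with the huggingface-provider feature enabled; it assumes EmbeddingError is re-exported from the crate root (the re-export path is not shown in this hunk):

    use stratum_embeddings::{EmbeddingError, HuggingFaceModel, HuggingFaceProvider};

    fn build_provider(api_key: &str) -> Result<HuggingFaceProvider, EmbeddingError> {
        // Requests go to BASE_URL (the HF feature-extraction pipeline endpoint);
        // an empty key is rejected here, before any request is made.
        HuggingFaceProvider::new(api_key, HuggingFaceModel::BgeLarge)
    }

    fn check_empty_key_is_rejected() {
        assert!(matches!(
            build_provider(""),
            Err(EmbeddingError::ConfigError(..))
        ));
    }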