chore: create stratum-embeddings and stratum-llm crates, docs
Commit 0ae853c2fa (parent b0d039d22d)
.gitignore (vendored, 1 line changed)

@@ -1,6 +1,7 @@
CLAUDE.md
.claude
utils/save*sh
.fastembed_cache
COMMIT_MESSAGE.md
.wrks
nushell
CHANGELOG.md (new file, 37 lines)

@@ -0,0 +1,37 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- **Architecture Documentation**: New `docs/*/architecture/` section with ADRs
  - ADR-001: stratum-embeddings - Unified embedding library with caching, fallback,
    and VectorStore trait (SurrealDB for Kogral, LanceDB for Provisioning/Vapora)
  - ADR-002: stratum-llm - Unified LLM provider library with CLI credential detection,
    circuit breaker, caching, and Kogral integration
- **Bilingual ADRs**: Full English and Spanish versions of all architecture documents
- **README updates**: Added Stratum Crates section and updated documentation structure

### Changed

- Documentation structure now includes `architecture/adrs/` subdirectory in both
  language directories (en/es)

## [0.1.0] - 2026-01-22

### Added

- Initial repository setup
- Main documentation structure (bilingual en/es)
- Branding assets (logos, icons, social variants)
- CI/CD configuration (GitHub Actions, Woodpecker)
- Language guidelines (Rust, Nickel, Nushell, Bash)
- Pre-commit hooks configuration

[Unreleased]: https://repo.jesusperez.pro/jesus/stratumiops/compare/v0.1.0...HEAD
[0.1.0]: https://repo.jesusperez.pro/jesus/stratumiops/releases/tag/v0.1.0
Cargo.lock (generated, new file, 10658 lines)

File diff suppressed because it is too large.
Cargo.toml (new file, 60 lines)

@@ -0,0 +1,60 @@
[workspace]
members = ["crates/*"]
resolver = "2"

[workspace.package]
edition = "2021"
license = "MIT OR Apache-2.0"

[workspace.dependencies]
# Async runtime
tokio = { version = "1.49", features = ["full"] }
async-trait = "0.1"
futures = "0.3"

# HTTP client
reqwest = { version = "0.13", features = ["json"] }

# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9"
humantime-serde = "1.1"

# Caching
moka = { version = "0.12", features = ["future"] }
sled = "0.34"

# Embeddings
fastembed = "5.8"

# Vector storage
lancedb = "0.23"
surrealdb = { version = "2.5", features = ["kv-mem"] }
# LOCKED: Arrow 56.x required for LanceDB 0.23 compatibility
# LanceDB 0.23 uses Arrow 56.2.0 internally - Arrow 57 breaks API compatibility
# DO NOT upgrade to Arrow 57 until LanceDB supports it
arrow = "=56"

# Error handling
thiserror = "2.0"
anyhow = "1.0"

# Logging and tracing
tracing = "0.1"
tracing-subscriber = "0.3"

# Metrics
prometheus = "0.14"

# Utilities
xxhash-rust = { version = "0.8", features = ["xxh3"] }
dirs = "6.0"
chrono = "0.4"
uuid = "1.19"
which = "8.0"

# Testing
tokio-test = "0.4"
approx = "0.5"
tempfile = "3.24"
README.md (17 lines changed)

@@ -131,16 +131,29 @@ StratumIOps is not a single project. It's the **orchestration layer** that coord
- **Integration Patterns**: How projects work together
- **Shared Standards**: Language guidelines (Rust, Nickel, Nushell, Bash)

### Stratum Crates

Shared infrastructure libraries for the ecosystem:

| Crate | Description | Status |
| ----- | ----------- | ------ |
| **stratum-embeddings** | Unified embedding providers with caching, fallback, and VectorStore trait | Proposed |
| **stratum-llm** | Unified LLM providers with CLI detection, circuit breaker, and caching | Proposed |

See [Architecture ADRs](docs/en/architecture/adrs/) for detailed design decisions.

### Documentation Structure

```text
docs/
├── en/                  # English documentation
│   ├── ia/              # AI/Development track
│   ├── ops/             # Ops/DevOps track
│   └── architecture/    # Architecture decisions (ADRs)
└── es/                  # Spanish documentation
    ├── ia/              # AI/Development track
    ├── ops/             # Ops/DevOps track
    └── architecture/    # Architecture decisions (ADRs)
```

### Branding Assets
crates/stratum-embeddings/Cargo.toml (new file, 113 lines)

@@ -0,0 +1,113 @@
[package]
name = "stratum-embeddings"
version = "0.1.0"
edition.workspace = true
description = "Unified embedding providers with caching, batch processing, and vector storage"
license.workspace = true

[dependencies]
# Async runtime
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }

# HTTP client (for cloud providers)
reqwest = { workspace = true, optional = true }

# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
humantime-serde = { workspace = true }

# Caching
moka = { workspace = true }

# Persistent cache (optional)
sled = { workspace = true, optional = true }

# Local embeddings
fastembed = { workspace = true, optional = true }

# Vector storage backends
lancedb = { workspace = true, optional = true }
surrealdb = { workspace = true, optional = true }
arrow = { workspace = true, optional = true }

# Error handling
thiserror = { workspace = true }

# Logging
tracing = { workspace = true }

# Metrics
prometheus = { workspace = true, optional = true }

# Utilities
xxhash-rust = { workspace = true }

[features]
default = ["fastembed-provider", "memory-cache"]

# Providers
fastembed-provider = ["fastembed"]
openai-provider = ["reqwest"]
ollama-provider = ["reqwest"]
cohere-provider = ["reqwest"]
voyage-provider = ["reqwest"]
huggingface-provider = ["reqwest"]
all-providers = [
    "fastembed-provider",
    "openai-provider",
    "ollama-provider",
    "cohere-provider",
    "voyage-provider",
    "huggingface-provider",
]

# Cache backends
memory-cache = []
persistent-cache = ["sled"]
all-cache = ["memory-cache", "persistent-cache"]

# Vector storage backends
lancedb-store = ["lancedb", "arrow"]
surrealdb-store = ["surrealdb"]
all-stores = ["lancedb-store", "surrealdb-store"]

# Observability
metrics = ["prometheus"]

# Project-specific presets
kogral = ["fastembed-provider", "memory-cache", "surrealdb-store"]
provisioning = ["openai-provider", "memory-cache", "lancedb-store"]
vapora = ["all-providers", "memory-cache", "lancedb-store"] # Includes huggingface-provider

# Full feature set
full = ["all-providers", "all-cache", "all-stores", "metrics"]

[dev-dependencies]
tokio-test = { workspace = true }
approx = { workspace = true }
tempfile = { workspace = true }
tracing-subscriber = { workspace = true }

# Example-specific feature requirements
[[example]]
name = "basic_usage"
required-features = ["fastembed-provider"]

[[example]]
name = "fallback_demo"
required-features = ["ollama-provider", "fastembed-provider"]

[[example]]
name = "lancedb_usage"
required-features = ["lancedb-store", "fastembed-provider"]

[[example]]
name = "surrealdb_usage"
required-features = ["surrealdb-store", "fastembed-provider"]

[[example]]
name = "huggingface_usage"
required-features = ["huggingface-provider"]
crates/stratum-embeddings/README.md (new file, 180 lines)

@@ -0,0 +1,180 @@
# stratum-embeddings

Unified embedding providers with caching, batch processing, and vector storage for the StratumIOps ecosystem.

## Features

- **Multiple Providers**: FastEmbed (local), OpenAI, Ollama
- **Smart Caching**: In-memory caching with configurable TTL
- **Batch Processing**: Efficient batch embedding with automatic chunking
- **Vector Storage**: LanceDB (scale-first) and SurrealDB (graph-first)
- **Fallback Support**: Automatic failover between providers
- **Feature Flags**: Modular compilation for minimal dependencies

## Architecture

```text
┌─────────────────────────────────────────┐
│            EmbeddingService             │
│    (facade with caching + fallback)     │
└─────────────┬───────────────────────────┘
              │
    ┌─────────┴─────────┐
    ▼                   ▼
┌─────────────┐   ┌─────────────┐
│  Providers  │   │    Cache    │
│             │   │             │
│ • FastEmbed │   │ • Memory    │
│ • OpenAI    │   │ • (Sled)    │
│ • Ollama    │   │             │
└─────────────┘   └─────────────┘
```

## Quick Start

### Basic Usage

```rust
use stratum_embeddings::{
    EmbeddingService, FastEmbedProvider, MemoryCache, EmbeddingOptions
};
use std::time::Duration;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let provider = FastEmbedProvider::small()?;
    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = EmbeddingService::new(provider).with_cache(cache);

    let options = EmbeddingOptions::default_with_cache();
    let embedding = service.embed("Hello world", &options).await?;

    println!("Generated {} dimensions", embedding.len());
    Ok(())
}
```

### Batch Processing

```rust
let texts = vec![
    "Text 1".to_string(),
    "Text 2".to_string(),
    "Text 3".to_string(),
];

let result = service.embed_batch(texts, &options).await?;
println!("Embeddings: {}, Cached: {}",
    result.embeddings.len(),
    result.cached_count
);
```

### Vector Storage

#### LanceDB (Provisioning, Vapora)

```rust
use stratum_embeddings::{LanceDbStore, VectorStore, VectorStoreConfig};

let config = VectorStoreConfig::new(384);
let store = LanceDbStore::new("./data", "embeddings", config).await?;

store.upsert("doc1", &embedding, metadata).await?;
let results = store.search(&query_embedding, 10, None).await?;
```

#### SurrealDB (Kogral)

```rust
use stratum_embeddings::{SurrealDbStore, VectorStore, VectorStoreConfig};

let config = VectorStoreConfig::new(384);
let store = SurrealDbStore::new_memory("concepts", config).await?;

store.upsert("concept1", &embedding, metadata).await?;
let results = store.search(&query_embedding, 10, None).await?;
```

## Feature Flags

### Providers

- `fastembed-provider` (default) - Local embeddings via fastembed
- `openai-provider` - OpenAI API embeddings
- `ollama-provider` - Ollama local server embeddings
- `all-providers` - All embedding providers

### Cache

- `memory-cache` (default) - In-memory caching with moka
- `persistent-cache` - Persistent cache with sled
- `all-cache` - All cache backends

### Vector Storage

- `lancedb-store` - LanceDB vector storage (columnar, disk-native)
- `surrealdb-store` - SurrealDB vector storage (graph + vector)
- `all-stores` - All storage backends

### Project Presets

- `kogral` - fastembed + memory + surrealdb (see the snippet below)
- `provisioning` - openai + memory + lancedb
- `vapora` - all-providers + memory + lancedb
- `full` - Everything enabled
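To pull one of these presets into a downstream crate, the manifest entry might look like the following sketch; the version and the `default-features = false` choice are assumptions, not taken from this commit:

```toml
# Hypothetical consumer Cargo.toml entry — version and default-features are assumptions.
[dependencies]
stratum-embeddings = { version = "0.1", default-features = false, features = ["kogral"] }
```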
## Examples

Run examples with:

```bash
cargo run --example basic_usage --features=default
cargo run --example fallback_demo --features=fastembed-provider,ollama-provider
cargo run --example lancedb_usage --features=lancedb-store
cargo run --example surrealdb_usage --features=surrealdb-store
```

## Provider Comparison

| Provider | Type | Cost | Dimensions | Use Case |
|----------|------|------|------------|----------|
| FastEmbed | Local | Free | 384-1024 | Dev, privacy-first |
| OpenAI | Cloud | $0.02-0.13/1M tokens | 1536-3072 | Production RAG |
| Ollama | Local | Free | 384-1024 | Self-hosted |

## Storage Backend Comparison

| Backend | Best For | Strength | Scale |
|---------|----------|----------|-------|
| LanceDB | RAG, traces | Columnar, IVF-PQ index | Billions |
| SurrealDB | Knowledge graphs | Unified graph+vector queries | Millions |

## Configuration

Environment variables:

```bash
# FastEmbed
FASTEMBED_MODEL=bge-small-en

# OpenAI
OPENAI_API_KEY=sk-...
OPENAI_MODEL=text-embedding-3-small

# Ollama
OLLAMA_MODEL=nomic-embed-text
OLLAMA_BASE_URL=http://localhost:11434
```
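As an illustration, a caller could select a provider based on which of these variables are set. This sketch reuses constructors shown elsewhere in this crate (`OllamaProvider::default_model`, `FastEmbedProvider::small`); the selection policy itself is an assumption, not crate API:

```rust
use std::sync::Arc;
use stratum_embeddings::{EmbeddingProvider, FastEmbedProvider, OllamaProvider};

// Sketch: prefer Ollama when OLLAMA_BASE_URL is set, else use local FastEmbed.
// The policy is illustrative; only the constructors come from this crate.
let provider: Arc<dyn EmbeddingProvider> = if std::env::var("OLLAMA_BASE_URL").is_ok() {
    Arc::new(OllamaProvider::default_model()?)
} else {
    Arc::new(FastEmbedProvider::small()?)
};
```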
## Development

```bash
cargo check -p stratum-embeddings --all-features
cargo test -p stratum-embeddings --all-features
cargo clippy -p stratum-embeddings --all-features -- -D warnings
```

## License

MIT OR Apache-2.0
crates/stratum-embeddings/docs/huggingface-provider.md (new file, 346 lines)

@@ -0,0 +1,346 @@
# HuggingFace Embedding Provider

Provider for HuggingFace Inference API embeddings, with support for popular sentence-transformers and BGE models.

## Overview

The HuggingFace provider uses the free Inference API to generate embeddings. It supports:

- **Public Models**: Free access to popular embedding models
- **Custom Models**: Support for any HuggingFace model with a feature-extraction pipeline
- **Automatic Caching**: Built-in memory cache reduces API calls
- **Response Normalization**: Optional L2 normalization for similarity search (see the sketch below)
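As a refresher, L2 normalization rescales a vector to unit length so that the dot product of two normalized vectors equals their cosine similarity. A minimal sketch of the operation (illustrative only, not the provider's internal code):

```rust
/// Rescale `v` in place to unit Euclidean length (no-op for the zero vector).
fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
```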
## Features

- ✅ Zero cost for public models (free Inference API)
- ✅ Support for 5+ popular models out of the box
- ✅ Custom model support with configurable dimensions
- ✅ Automatic retry with exponential backoff
- ✅ Rate limit handling
- ✅ Integration with the stratum-embeddings caching layer

## Supported Models

### Predefined Models

| Model | Dimensions | Use Case | Constructor |
|-------|------------|----------|-------------|
| **BAAI/bge-small-en-v1.5** | 384 | General-purpose, efficient | `HuggingFaceProvider::bge_small()` |
| **BAAI/bge-base-en-v1.5** | 768 | Balanced performance | `HuggingFaceProvider::bge_base()` |
| **BAAI/bge-large-en-v1.5** | 1024 | High quality | `HuggingFaceProvider::bge_large()` |
| **sentence-transformers/all-MiniLM-L6-v2** | 384 | Fast, lightweight | `HuggingFaceProvider::all_minilm()` |
| **sentence-transformers/all-mpnet-base-v2** | 768 | Strong baseline | - |

### Custom Models

```rust
let model = HuggingFaceModel::Custom(
    "sentence-transformers/paraphrase-MiniLM-L6-v2".to_string(),
    384,
);
let provider = HuggingFaceProvider::new(api_key, model)?;
```

## API Rate Limits

### Free Inference API

The HuggingFace Inference API has the following rate limits:

| Tier | Requests/Hour | Requests/Day | Max Concurrent |
|------|---------------|--------------|----------------|
| **Anonymous** | 1,000 | 10,000 | 1 |
| **Free Account** | 3,000 | 30,000 | 3 |
| **PRO ($9/mo)** | 10,000 | 100,000 | 10 |
| **Enterprise** | Custom | Custom | Custom |

**Rate Limit Headers**:

```text
X-RateLimit-Limit: 3000
X-RateLimit-Remaining: 2999
X-RateLimit-Reset: 1234567890
```

### Rate Limit Handling

The provider handles rate limits automatically with:

1. **Exponential Backoff**: Retries with increasing delays (1s, 2s, 4s, 8s)
2. **Max Retries**: Default of 3 retries before failing
3. **Circuit Breaker**: Automatically pauses requests if rate limited repeatedly
4. **Cache Integration**: Reduces API calls by 70-90% for repeated queries

**Configuration**:

```rust
// Default retry config (built-in)
let provider = HuggingFaceProvider::new(api_key, model)?;

// With custom retry (future enhancement)
let provider = HuggingFaceProvider::new(api_key, model)?
    .with_retry_config(RetryConfig {
        max_retries: 5,
        initial_delay: Duration::from_secs(2),
        max_delay: Duration::from_secs(30),
    });
```

### Best Practices for Rate Limits

1. **Enable Caching**: Use `EmbeddingOptions::default_with_cache()`

   ```rust
   let options = EmbeddingOptions::default_with_cache();
   let embedding = provider.embed(text, &options).await?;
   ```

2. **Batch Requests Carefully**: The HuggingFace Inference API processes requests sequentially

   ```rust
   // This makes N API calls sequentially
   let texts = vec!["text1", "text2", "text3"];
   let result = provider.embed_batch(&texts, &options).await?;
   ```

3. **Use a PRO Account for Production**: The free tier is suitable for development only

4. **Monitor Rate Limits**: Check response headers

   ```rust
   // Future enhancement - rate limit monitoring
   let stats = provider.rate_limit_stats();
   println!("Remaining: {}/{}", stats.remaining, stats.limit);
   ```

## Authentication

### Environment Variables

The provider checks for API keys in this order:

1. `HUGGINGFACE_API_KEY`
2. `HF_TOKEN` (alternative name)

```bash
export HUGGINGFACE_API_KEY="hf_xxxxxxxxxxxxxxxxxxxx"
```
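This precedence mirrors how `examples/huggingface_usage.rs` in this crate resolves credentials:

```rust
// Check HUGGINGFACE_API_KEY first, then fall back to HF_TOKEN.
let api_key = std::env::var("HUGGINGFACE_API_KEY")
    .or_else(|_| std::env::var("HF_TOKEN"))
    .expect("Set HUGGINGFACE_API_KEY or HF_TOKEN");
```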
### Getting an API Token

1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens)
2. Click "New token"
3. Select "Read" access (sufficient for the Inference API)
4. Copy the token starting with `hf_`

## Usage Examples

### Basic Usage

```rust
use stratum_embeddings::{HuggingFaceProvider, EmbeddingOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Using a predefined model
    let provider = HuggingFaceProvider::bge_small()?;

    let options = EmbeddingOptions::default_with_cache();
    let embedding = provider.embed("Hello world", &options).await?;

    println!("Dimensions: {}", embedding.len()); // 384
    Ok(())
}
```

### With EmbeddingService (Recommended)

```rust
use std::time::Duration;
use stratum_embeddings::{
    HuggingFaceProvider, EmbeddingService, MemoryCache, EmbeddingOptions
};

let provider = HuggingFaceProvider::bge_small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(3600));

let service = EmbeddingService::new(provider)
    .with_cache(cache);

let options = EmbeddingOptions::default_with_cache();
let embedding = service.embed("Cached embeddings", &options).await?;
```

### Semantic Similarity Search

```rust
use stratum_embeddings::{HuggingFaceProvider, EmbeddingOptions, cosine_similarity};

let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions {
    normalize: true, // Important for cosine similarity
    truncate: true,
    use_cache: true,
};

let query = "machine learning";
let doc1 = "deep learning and neural networks";
let doc2 = "cooking recipes";

let query_emb = provider.embed(query, &options).await?;
let doc1_emb = provider.embed(doc1, &options).await?;
let doc2_emb = provider.embed(doc2, &options).await?;

let sim1 = cosine_similarity(&query_emb, &doc1_emb);
let sim2 = cosine_similarity(&query_emb, &doc2_emb);

println!("Similarity with doc1: {:.4}", sim1); // ~0.85
println!("Similarity with doc2: {:.4}", sim2); // ~0.15
```

### Custom Model

```rust
use stratum_embeddings::{HuggingFaceProvider, HuggingFaceModel};

let api_key = std::env::var("HUGGINGFACE_API_KEY")?;
let model = HuggingFaceModel::Custom(
    "intfloat/multilingual-e5-large".to_string(),
    1024, // Specify dimensions
);

let provider = HuggingFaceProvider::new(api_key, model)?;
```

## Error Handling

### Common Errors

| Error | Cause | Solution |
|-------|-------|----------|
| `ConfigError: API key is empty` | Missing credentials | Set `HUGGINGFACE_API_KEY` |
| `ApiError: HTTP 401` | Invalid API token | Check token validity |
| `ApiError: HTTP 429` | Rate limit exceeded | Wait or upgrade tier |
| `ApiError: HTTP 503` | Model loading | Retry after ~20s |
| `DimensionMismatch` | Wrong model dimensions | Update `Custom` model dims |

### Retry Example

```rust
use tokio::time::sleep;
use std::time::Duration;

let mut retries = 0;
let max_retries = 3;

loop {
    match provider.embed(text, &options).await {
        Ok(embedding) => break Ok(embedding),
        Err(e) if e.to_string().contains("429") && retries < max_retries => {
            retries += 1;
            let delay = Duration::from_secs(2u64.pow(retries));
            eprintln!("Rate limited, retrying in {:?}...", delay);
            sleep(delay).await;
        }
        Err(e) => break Err(e),
    }
}
```

## Performance Characteristics

### Latency

| Operation | Latency | Notes |
|-----------|---------|-------|
| **Single embed** | 200-500ms | Depends on model size and region |
| **Batch (N items)** | N × 200-500ms | Sequential processing |
| **Cache hit** | <1ms | In-memory lookup |
| **Cold start** | +5-20s | First request loads the model |

### Throughput

| Tier | Max RPS | Daily Limit |
|------|---------|-------------|
| Free | ~0.8 | 30,000 |
| PRO | ~2.8 | 100,000 |

**With Caching** (80% hit rate), only cache misses consume the API budget, so effective throughput is roughly the raw rate divided by the 20% miss rate:

- Free tier: ~4 effective RPS
- PRO tier: ~14 effective RPS

## Cost Comparison

| Provider | Cost/1M Tokens | Free Tier | Notes |
|----------|----------------|-----------|-------|
| **HuggingFace** | $0.00 | 30k req/day | Free for public models |
| OpenAI | $0.02-0.13 | $5 credit | Pay per token |
| Cohere | $0.10 | 100 req/month | Limited free tier |
| Voyage | $0.12 | None | No free tier |

## Limitations

1. **No True Batching**: The Inference API processes one request at a time
2. **Cold Starts**: Models need ~20s to load on the first request
3. **Rate Limits**: The free tier is suitable for development only
4. **Regional Latency**: Single region (US/EU), no edge locations
5. **Model Loading**: Popular models are cached; custom models may be slow

## Advanced Configuration

### Model Loading Timeout

```rust
// Future enhancement
let provider = HuggingFaceProvider::new(api_key, model)?
    .with_timeout(Duration::from_secs(120)); // Wait longer for cold starts
```

### Dedicated Inference Endpoints

For production workloads, consider [Dedicated Endpoints](https://huggingface.co/inference-endpoints):

- True batch processing
- Guaranteed uptime
- No rate limits
- Custom regions
- ~$60-500/month

## Migration Guide

### From the vapora Custom Implementation

**Before**:

```rust
let hf = HuggingFaceEmbedding::new(api_key, "BAAI/bge-small-en-v1.5".to_string());
let embedding = hf.embed(text).await?;
```

**After**:

```rust
let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions::default_with_cache();
let embedding = provider.embed(text, &options).await?;
```

### From OpenAI

```rust
// OpenAI (paid)
let provider = OpenAiProvider::new(api_key, OpenAiModel::TextEmbedding3Small)?;

// HuggingFace (free, similar quality)
let provider = HuggingFaceProvider::bge_small()?;
```

## Running the Example

```bash
export HUGGINGFACE_API_KEY="hf_xxxxxxxxxxxxxxxxxxxx"

cargo run --example huggingface_usage \
    --features huggingface-provider
```

## References

- [HuggingFace Inference API Docs](https://huggingface.co/docs/api-inference/index)
- [BGE Embedding Models](https://huggingface.co/BAAI)
- [Sentence Transformers](https://www.sbert.net/)
- [Rate Limits Documentation](https://huggingface.co/docs/api-inference/rate-limits)
crates/stratum-embeddings/examples/basic_usage.rs (new file, 50 lines)

@@ -0,0 +1,50 @@
use std::time::Duration;

use stratum_embeddings::{EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("Initializing FastEmbed provider...");
    let provider = FastEmbedProvider::small()?;

    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = EmbeddingService::new(provider).with_cache(cache);

    info!("Service ready: {:?}", service.provider_info());

    let options = EmbeddingOptions::default_with_cache();

    info!("Embedding single text...");
    let text = "Stratum embeddings is a unified embedding library";
    let embedding = service.embed(text, &options).await?;

    info!("Generated embedding with {} dimensions", embedding.len());

    info!("Embedding same text again (should be cached)...");
    let embedding2 = service.embed(text, &options).await?;
    assert_eq!(embedding, embedding2);
    info!("Cache hit confirmed!");

    info!("Embedding batch of texts...");
    let texts = vec![
        "Rust is a systems programming language".to_string(),
        "Knowledge graphs connect concepts".to_string(),
        "Vector databases enable semantic search".to_string(),
    ];

    let result = service.embed_batch(texts, &options).await?;

    info!(
        "Batch complete: {} embeddings generated",
        result.embeddings.len()
    );
    info!("Model: {}, Dimensions: {}", result.model, result.dimensions);
    info!("Cached count: {}", result.cached_count);

    info!("Cache size: {}", service.cache_size());

    Ok(())
}
crates/stratum-embeddings/examples/fallback_demo.rs (new file, 44 lines)

@@ -0,0 +1,44 @@
use std::{sync::Arc, time::Duration};

use stratum_embeddings::{
    EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache, OllamaProvider,
};
use tracing::{info, warn};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("Setting up primary provider (Ollama)...");
    let primary = OllamaProvider::default_model()?;

    info!("Setting up fallback provider (FastEmbed)...");
    let fallback =
        Arc::new(FastEmbedProvider::small()?) as Arc<dyn stratum_embeddings::EmbeddingProvider>;

    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = EmbeddingService::new(primary)
        .with_cache(cache)
        .with_fallback(fallback);

    let options = EmbeddingOptions::default_with_cache();

    info!("Checking if Ollama is available...");
    if service.is_ready().await {
        info!("Ollama is available, using as primary");
    } else {
        warn!("Ollama not available, will fall back to FastEmbed");
    }

    info!("Embedding text (will use available provider)...");
    let text = "This demonstrates fallback strategy in action";
    let embedding = service.embed(text, &options).await?;

    info!(
        "Successfully generated embedding with {} dimensions",
        embedding.len()
    );
    info!("Cache size: {}", service.cache_size());

    Ok(())
}
crates/stratum-embeddings/examples/huggingface_usage.rs (new file, 125 lines)

@@ -0,0 +1,125 @@
use std::time::Duration;

use stratum_embeddings::{
    EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("=== HuggingFace Embedding Provider Demo ===");

    // Example 1: Using a predefined model (bge-small)
    info!("\n1. Using predefined BGE-small model (384 dimensions)");
    let provider = HuggingFaceProvider::bge_small()?;
    let options = EmbeddingOptions::default_with_cache();

    let text = "HuggingFace provides free inference API for embedding models";
    let embedding = provider.embed(text, &options).await?;

    info!(
        "Generated embedding with {} dimensions from BGE-small",
        embedding.len()
    );
    info!("First 5 values: {:?}", &embedding[..5]);

    // Example 2: Using a different model size
    info!("\n2. Using BGE-base model (768 dimensions)");
    let provider = HuggingFaceProvider::bge_base()?;
    let embedding = provider.embed(text, &options).await?;

    info!(
        "Generated embedding with {} dimensions from BGE-base",
        embedding.len()
    );

    // Example 3: Using a custom model
    info!("\n3. Using custom model");
    let api_key = std::env::var("HUGGINGFACE_API_KEY")
        .or_else(|_| std::env::var("HF_TOKEN"))
        .expect("Set HUGGINGFACE_API_KEY or HF_TOKEN");

    let custom_model = HuggingFaceModel::Custom(
        "sentence-transformers/paraphrase-MiniLM-L6-v2".to_string(),
        384,
    );
    let provider = HuggingFaceProvider::new(api_key, custom_model)?;
    let embedding = provider.embed(text, &options).await?;

    info!(
        "Custom model embedding: {} dimensions",
        embedding.len()
    );

    // Example 4: Batch embeddings (sequential requests to the HF API)
    info!("\n4. Batch embedding (sequential API calls)");
    let provider = HuggingFaceProvider::all_minilm()?;

    let texts = vec![
        "First document about embeddings",
        "Second document about transformers",
        "Third document about NLP",
    ];

    // `texts` is already a Vec<&str>, so it can be passed to embed_batch directly.
    let result = provider.embed_batch(&texts, &options).await?;

    info!("Embedded {} texts", result.embeddings.len());
    for (i, emb) in result.embeddings.iter().enumerate() {
        info!("  Text {}: {} dimensions", i + 1, emb.len());
    }

    // Example 5: Using with a cache
    info!("\n5. Demonstrating cache effectiveness");
    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = stratum_embeddings::EmbeddingService::new(
        HuggingFaceProvider::bge_small()?
    ).with_cache(cache);

    let cached_options = EmbeddingOptions::default_with_cache();

    // First call - cache miss
    let start = std::time::Instant::now();
    let _ = service.embed(text, &cached_options).await?;
    let first_duration = start.elapsed();
    info!("First call (cache miss): {:?}", first_duration);

    // Second call - cache hit
    let start = std::time::Instant::now();
    let _ = service.embed(text, &cached_options).await?;
    let second_duration = start.elapsed();
    info!("Second call (cache hit): {:?}", second_duration);
    info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64());
    info!("Cache size: {}", service.cache_size());

    // Example 6: Normalized embeddings for similarity search
    info!("\n6. Normalized embeddings for similarity");
    let provider = HuggingFaceProvider::bge_small()?;
    let normalize_options = EmbeddingOptions {
        normalize: true,
        truncate: true,
        use_cache: true,
    };

    let query = "machine learning embeddings";
    let doc1 = "neural network embeddings for NLP";
    let doc2 = "cooking recipes and ingredients";

    let query_emb = provider.embed(query, &normalize_options).await?;
    let doc1_emb = provider.embed(doc1, &normalize_options).await?;
    let doc2_emb = provider.embed(doc2, &normalize_options).await?;

    let sim1 = stratum_embeddings::cosine_similarity(&query_emb, &doc1_emb);
    let sim2 = stratum_embeddings::cosine_similarity(&query_emb, &doc2_emb);

    info!("Query: '{}'", query);
    info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
    info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
    info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });

    info!("\n=== Demo Complete ===");

    Ok(())
}
crates/stratum-embeddings/examples/lancedb_usage.rs (new file, 67 lines)

@@ -0,0 +1,67 @@
use std::time::Duration;

use stratum_embeddings::{
    EmbeddingOptions, EmbeddingService, FastEmbedProvider, LanceDbStore, MemoryCache, VectorStore,
    VectorStoreConfig,
};
use tempfile::tempdir;
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("Initializing embedding service...");
    let provider = FastEmbedProvider::small()?;
    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = EmbeddingService::new(provider).with_cache(cache);

    let dir = tempdir()?;
    let db_path = dir.path().to_str().unwrap();

    info!("Creating LanceDB store at: {}", db_path);
    let config = VectorStoreConfig::new(384);
    let store = LanceDbStore::new(db_path, "embeddings", config).await?;

    let documents = vec![
        (
            "doc1",
            "Rust provides memory safety without garbage collection",
        ),
        ("doc2", "Knowledge graphs represent structured information"),
        ("doc3", "Vector databases enable semantic similarity search"),
        ("doc4", "Machine learning models learn from data patterns"),
        ("doc5", "Embeddings capture semantic meaning in vectors"),
    ];

    info!("Embedding and storing {} documents...", documents.len());
    let options = EmbeddingOptions::default_with_cache();

    for (id, text) in &documents {
        let embedding = service.embed(text, &options).await?;
        let metadata = serde_json::json!({
            "text": text,
            "source": "demo"
        });
        store.upsert(id, &embedding, metadata).await?;
    }

    info!("Documents stored successfully");

    info!("Performing semantic search...");
    let query = "How do databases support similarity matching?";
    let query_embedding = service.embed(query, &options).await?;

    let results = store.search(&query_embedding, 3, None).await?;

    info!("Search results for: '{}'", query);
    for (i, result) in results.iter().enumerate() {
        let text = result.metadata["text"].as_str().unwrap_or("N/A");
        info!("  {}. [score: {:.4}] {}", i + 1, result.score, text);
    }

    let count = store.count().await?;
    info!("Total documents in store: {}", count);

    Ok(())
}
crates/stratum-embeddings/examples/surrealdb_usage.rs (new file, 66 lines)

@@ -0,0 +1,66 @@
use std::time::Duration;

use stratum_embeddings::{
    EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache, SurrealDbStore,
    VectorStore, VectorStoreConfig,
};
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    info!("Initializing embedding service...");
    let provider = FastEmbedProvider::small()?;
    let cache = MemoryCache::new(1000, Duration::from_secs(300));
    let service = EmbeddingService::new(provider).with_cache(cache);

    info!("Creating SurrealDB in-memory store...");
    let config = VectorStoreConfig::new(384);
    let store = SurrealDbStore::new_memory("concepts", config).await?;

    let concepts = vec![
        ("ownership", "Rust's ownership system prevents memory leaks"),
        (
            "borrowing",
            "Borrowing allows references without ownership transfer",
        ),
        ("lifetimes", "Lifetimes ensure references remain valid"),
        ("traits", "Traits define shared behavior across types"),
        ("generics", "Generics enable code reuse with type safety"),
    ];

    info!("Embedding and storing {} concepts...", concepts.len());
    let options = EmbeddingOptions::default_with_cache();

    for (id, description) in &concepts {
        let embedding = service.embed(description, &options).await?;
        let metadata = serde_json::json!({
            "concept": id,
            "description": description,
            "language": "rust"
        });
        store.upsert(id, &embedding, metadata).await?;
    }

    info!("Concepts stored successfully");

    info!("Performing knowledge graph search...");
    let query = "How does Rust manage memory?";
    let query_embedding = service.embed(query, &options).await?;

    let results = store.search(&query_embedding, 3, None).await?;

    info!("Most relevant concepts for: '{}'", query);
    for (i, result) in results.iter().enumerate() {
        let concept = result.metadata["concept"].as_str().unwrap_or("N/A");
        let description = result.metadata["description"].as_str().unwrap_or("N/A");
        info!("  {}. {} [score: {:.4}]", i + 1, concept, result.score);
        info!("     {}", description);
    }

    let count = store.count().await?;
    info!("Total concepts in graph: {}", count);

    Ok(())
}
crates/stratum-embeddings/src/batch.rs (new file, 312 lines)

@@ -0,0 +1,312 @@
use std::sync::Arc;

use futures::stream::{self, StreamExt};
use tracing::{debug, info};

use crate::{
    cache::{cache_key, EmbeddingCache},
    error::EmbeddingError,
    traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};

/// Chunked, cache-aware batch embedding on top of any provider/cache pair.
pub struct BatchProcessor<P: EmbeddingProvider, C: EmbeddingCache> {
    provider: Arc<P>,
    cache: Option<Arc<C>>,
    max_concurrent: usize,
}

impl<P: EmbeddingProvider, C: EmbeddingCache> BatchProcessor<P, C> {
    pub fn new(provider: Arc<P>, cache: Option<Arc<C>>) -> Self {
        Self {
            provider,
            cache,
            max_concurrent: 10,
        }
    }

    pub fn with_concurrency(mut self, max_concurrent: usize) -> Self {
        self.max_concurrent = max_concurrent;
        self
    }

    /// Embeds `texts` in provider-sized chunks, serving cache hits from the
    /// cache and only sending cache misses to the provider.
    pub async fn process_batch(
        &self,
        texts: Vec<String>,
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        let provider_batch_size = self.provider.max_batch_size();
        let cache_enabled = options.use_cache && self.cache.is_some();

        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut total_tokens = 0u32;
        let mut cached_count = 0usize;

        for chunk in texts.chunks(provider_batch_size) {
            let (cache_hits, cache_misses) = if cache_enabled {
                self.check_cache(chunk, self.provider.name(), self.provider.model())
                    .await
            } else {
                (vec![None; chunk.len()], (0..chunk.len()).collect())
            };

            let mut chunk_embeddings = cache_hits;
            cached_count += chunk_embeddings.iter().filter(|e| e.is_some()).count();

            if !cache_misses.is_empty() {
                let texts_to_embed: Vec<&str> = cache_misses
                    .iter()
                    .map(|&idx| chunk[idx].as_str())
                    .collect();

                debug!(
                    "Embedding {} texts (cached: {}, new: {})",
                    chunk.len(),
                    cached_count,
                    texts_to_embed.len()
                );

                let result = self.provider.embed_batch(&texts_to_embed, options).await?;

                if let Some(tokens) = result.total_tokens {
                    total_tokens += tokens;
                }

                if let Some(cache) = cache_enabled.then_some(self.cache.as_ref()).flatten() {
                    let cache_items = Self::build_cache_items(
                        self.provider.name(),
                        self.provider.model(),
                        chunk,
                        &cache_misses,
                        &result.embeddings,
                    );
                    cache.insert_batch(cache_items).await;
                }

                for (miss_idx, embedding) in cache_misses.iter().zip(result.embeddings.into_iter())
                {
                    chunk_embeddings[*miss_idx] = Some(embedding);
                }
            }

            all_embeddings.extend(
                chunk_embeddings
                    .into_iter()
                    .map(|e| e.expect("Missing embedding")),
            );
        }

        info!(
            "Batch complete: {} embeddings ({} cached, {} new)",
            texts.len(),
            cached_count,
            texts.len() - cached_count
        );

        Ok(EmbeddingResult {
            embeddings: all_embeddings,
            model: self.provider.model().to_string(),
            dimensions: self.provider.dimensions(),
            total_tokens: if total_tokens > 0 {
                Some(total_tokens)
            } else {
                None
            },
            cached_count,
        })
    }

    /// Embeds texts concurrently, with at most `max_concurrent` requests in
    /// flight, consulting the cache for each text individually.
    pub async fn process_stream(
        &self,
        texts: Vec<String>,
        options: &EmbeddingOptions,
    ) -> Result<Vec<Embedding>, EmbeddingError> {
        let provider = Arc::clone(&self.provider);
        let cache = self.cache.clone();
        let provider_name = provider.name().to_string();
        let provider_model = provider.model().to_string();
        let opts = options.clone();

        let embeddings: Vec<Embedding> = stream::iter(texts)
            .map(move |text| {
                let provider = Arc::clone(&provider);
                let cache = cache.clone();
                let provider_name = provider_name.clone();
                let provider_model = provider_model.clone();
                let opts = opts.clone();

                Self::embed_with_cache(provider, cache, text, provider_name, provider_model, opts)
            })
            // `buffered` (rather than `buffer_unordered`) preserves input order,
            // so the returned embeddings line up with the input texts.
            .buffered(self.max_concurrent)
            .collect::<Vec<Result<Embedding, EmbeddingError>>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        Ok(embeddings)
    }

    fn build_cache_items(
        provider_name: &str,
        provider_model: &str,
        chunk: &[String],
        cache_misses: &[usize],
        embeddings: &[Embedding],
    ) -> Vec<(String, Embedding)> {
        cache_misses
            .iter()
            .zip(embeddings.iter())
            .map(|(&idx, emb)| {
                (
                    cache_key(provider_name, provider_model, &chunk[idx]),
                    emb.clone(),
                )
            })
            .collect()
    }

    async fn embed_with_cache(
        provider: Arc<P>,
        cache: Option<Arc<C>>,
        text: String,
        provider_name: String,
        provider_model: String,
        opts: EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        let key = cache_key(&provider_name, &provider_model, &text);

        if opts.use_cache {
            if let Some(cached) = Self::try_get_cached(&cache, &key).await {
                return Ok(cached);
            }
        }

        let embedding = provider.embed(&text, &opts).await?;

        if opts.use_cache {
            Self::cache_insert(&cache, &key, embedding.clone()).await;
        }

        Ok(embedding)
    }

    async fn try_get_cached(cache: &Option<Arc<C>>, key: &str) -> Option<Embedding> {
        match cache {
            Some(c) => c.get(key).await,
            None => None,
        }
    }

    async fn cache_insert(cache: &Option<Arc<C>>, key: &str, embedding: Embedding) {
        if let Some(c) = cache {
            c.insert(key, embedding).await;
        }
    }

    async fn check_cache(
        &self,
        texts: &[String],
        provider_name: &str,
        model_name: &str,
    ) -> (Vec<Option<Embedding>>, Vec<usize>) {
        let cache = match &self.cache {
            Some(c) => c,
            None => return (vec![None; texts.len()], (0..texts.len()).collect()),
        };

        let keys: Vec<String> = texts
            .iter()
            .map(|text| cache_key(provider_name, model_name, text))
            .collect();

        let cached = cache.get_batch(&keys).await;
        let misses: Vec<usize> = cached
            .iter()
            .enumerate()
            .filter(|(_, e)| e.is_none())
            .map(|(i, _)| i)
            .collect();

        (cached, misses)
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{cache::MemoryCache, providers::FastEmbedProvider};

    #[tokio::test]
    async fn test_batch_processor_no_cache() {
        let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
        let processor: BatchProcessor<_, MemoryCache> = BatchProcessor::new(provider, None);

        let texts = vec!["Hello world".to_string(), "Goodbye world".to_string()];

        let options = EmbeddingOptions::no_cache();
        let result = processor
            .process_batch(texts, &options)
            .await
            .expect("Failed to process");

        assert_eq!(result.embeddings.len(), 2);
        assert_eq!(result.cached_count, 0);
    }

    #[tokio::test]
    async fn test_batch_processor_with_cache() {
        let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
        let cache = Arc::new(MemoryCache::new(100, Duration::from_secs(60)));
        let processor = BatchProcessor::new(provider, Some(cache));

        let texts = vec!["Hello world".to_string(), "Goodbye world".to_string()];

        let options = EmbeddingOptions::default_with_cache();
        let result1 = processor
            .process_batch(texts.clone(), &options)
            .await
            .expect("Failed first batch");

        assert_eq!(result1.embeddings.len(), 2);
        assert_eq!(result1.cached_count, 0);

        let result2 = processor
            .process_batch(texts, &options)
            .await
            .expect("Failed second batch");

        assert_eq!(result2.embeddings.len(), 2);
        assert_eq!(result2.cached_count, 2);

        assert_eq!(result1.embeddings, result2.embeddings);
    }

    #[tokio::test]
    async fn test_batch_processor_stream() {
        let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
        let cache = Arc::new(MemoryCache::with_defaults());
        let processor = BatchProcessor::new(provider, Some(cache)).with_concurrency(2);

        let texts = vec![
            "Text 1".to_string(),
            "Text 2".to_string(),
            "Text 3".to_string(),
            "Text 4".to_string(),
        ];

        let options = EmbeddingOptions::default_with_cache();
        let embeddings = processor
            .process_stream(texts, &options)
            .await
            .expect("Failed stream");

        assert_eq!(embeddings.len(), 4);
    }
}
crates/stratum-embeddings/src/cache/memory.rs (vendored, new file, 167 lines)

@@ -0,0 +1,167 @@
#[cfg(feature = "memory-cache")]
use std::time::Duration;

#[cfg(feature = "memory-cache")]
use async_trait::async_trait;
#[cfg(feature = "memory-cache")]
use moka::future::Cache;

#[cfg(feature = "memory-cache")]
use crate::{cache::EmbeddingCache, traits::Embedding};

pub struct MemoryCache {
    cache: Cache<String, Embedding>,
}

impl MemoryCache {
    pub fn new(max_capacity: u64, ttl: Duration) -> Self {
        let cache = Cache::builder()
            .max_capacity(max_capacity)
            .time_to_live(ttl)
            .build();

        Self { cache }
    }

    pub fn with_defaults() -> Self {
        Self::new(10_000, Duration::from_secs(3600))
    }

    pub fn unlimited(ttl: Duration) -> Self {
        let cache = Cache::builder().time_to_live(ttl).build();

        Self { cache }
    }
}

impl Default for MemoryCache {
    fn default() -> Self {
        Self::with_defaults()
    }
}

#[async_trait]
impl EmbeddingCache for MemoryCache {
    async fn get(&self, key: &str) -> Option<Embedding> {
        self.cache.get(key).await
    }

    async fn insert(&self, key: &str, embedding: Embedding) {
        self.cache.insert(key.to_string(), embedding).await;
        self.cache.run_pending_tasks().await;
    }

    async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>> {
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            results.push(self.cache.get(key).await);
        }
        results
    }

    async fn insert_batch(&self, items: Vec<(String, Embedding)>) {
        for (key, embedding) in items {
            self.cache.insert(key, embedding).await;
        }
        self.cache.run_pending_tasks().await;
    }

    async fn invalidate(&self, key: &str) {
        self.cache.invalidate(key).await;
        self.cache.run_pending_tasks().await;
    }

    async fn clear(&self) {
        self.cache.invalidate_all();
        self.cache.run_pending_tasks().await;
    }

    fn size(&self) -> usize {
        self.cache.entry_count() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_cache_basic() {
        let cache = MemoryCache::with_defaults();

        let embedding = vec![1.0, 2.0, 3.0];
        cache.insert("test_key", embedding.clone()).await;

        let retrieved = cache.get("test_key").await;
        assert_eq!(retrieved, Some(embedding));

        let missing = cache.get("missing_key").await;
        assert_eq!(missing, None);
    }

    #[tokio::test]
    async fn test_memory_cache_batch() {
        let cache = MemoryCache::with_defaults();

        let items = vec![
            ("key1".to_string(), vec![1.0, 2.0]),
            ("key2".to_string(), vec![3.0, 4.0]),
            ("key3".to_string(), vec![5.0, 6.0]),
        ];

        cache.insert_batch(items.clone()).await;

        let keys = vec![
            "key1".to_string(),
            "key2".to_string(),
            "missing".to_string(),
        ];
        let results = cache.get_batch(&keys).await;

        assert_eq!(results.len(), 3);
        assert_eq!(results[0], Some(vec![1.0, 2.0]));
        assert_eq!(results[1], Some(vec![3.0, 4.0]));
        assert_eq!(results[2], None);
    }

    #[tokio::test]
    async fn test_memory_cache_invalidate() {
        let cache = MemoryCache::with_defaults();

        cache.insert("key1", vec![1.0, 2.0]).await;
        cache.insert("key2", vec![3.0, 4.0]).await;

        assert!(cache.get("key1").await.is_some());
        assert!(cache.get("key2").await.is_some());

        cache.invalidate("key1").await;

        assert!(cache.get("key1").await.is_none());
        assert!(cache.get("key2").await.is_some());
    }

    #[tokio::test]
    async fn test_memory_cache_clear() {
        let cache = MemoryCache::with_defaults();

        cache.insert("key1", vec![1.0]).await;
        cache.insert("key2", vec![2.0]).await;

        cache.clear().await;

        assert!(cache.get("key1").await.is_none());
        assert!(cache.get("key2").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let cache = MemoryCache::new(1000, Duration::from_millis(100));

        cache.insert("key1", vec![1.0, 2.0]).await;
        assert!(cache.get("key1").await.is_some());

        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(cache.get("key1").await.is_none());
    }
}
47
crates/stratum-embeddings/src/cache/mod.rs
vendored
Normal file
@ -0,0 +1,47 @@
#[cfg(feature = "memory-cache")]
pub mod memory;
#[cfg(feature = "persistent-cache")]
pub mod persistent;

use async_trait::async_trait;
#[cfg(feature = "memory-cache")]
pub use memory::MemoryCache;
#[cfg(feature = "persistent-cache")]
pub use persistent::PersistentCache;

use crate::traits::Embedding;

#[async_trait]
pub trait EmbeddingCache: Send + Sync {
    async fn get(&self, key: &str) -> Option<Embedding>;
    async fn insert(&self, key: &str, embedding: Embedding);
    async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>>;
    async fn insert_batch(&self, items: Vec<(String, Embedding)>);
    async fn invalidate(&self, key: &str);
    async fn clear(&self);
    fn size(&self) -> usize;
}

pub fn cache_key(provider: &str, model: &str, text: &str) -> String {
    use xxhash_rust::xxh3::xxh3_64;
    let hash = xxh3_64(format!("{}:{}:{}", provider, model, text).as_bytes());
    format!("{}:{}:{:x}", provider, model, hash)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_consistency() {
        let key1 = cache_key("fastembed", "bge-small", "hello world");
        let key2 = cache_key("fastembed", "bge-small", "hello world");
        assert_eq!(key1, key2);

        let key3 = cache_key("fastembed", "bge-small", "hello world!");
        assert_ne!(key1, key3);

        let key4 = cache_key("openai", "bge-small", "hello world");
        assert_ne!(key1, key4);
    }
}
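The cache_key scheme prefixes the xxHash3 digest with provider and model, so the same text embedded by two different models never collides. A hypothetical helper (not part of the crate) wiring the key function, the cache trait, and a provider together:

use stratum_embeddings::cache::{cache_key, EmbeddingCache};
use stratum_embeddings::{EmbeddingOptions, EmbeddingProvider, Result};

// Hypothetical: check the cache first, fall back to the provider, then
// store the fresh embedding under the provider/model-scoped key.
async fn get_or_embed(
    cache: &dyn EmbeddingCache,
    provider: &dyn EmbeddingProvider,
    text: &str,
) -> Result<Vec<f32>> {
    let key = cache_key(provider.name(), provider.model(), text);
    if let Some(hit) = cache.get(&key).await {
        return Ok(hit);
    }
    let embedding = provider
        .embed(text, &EmbeddingOptions::default_with_cache())
        .await?;
    cache.insert(&key, embedding.clone()).await;
    Ok(embedding)
}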
152
crates/stratum-embeddings/src/cache/persistent.rs
vendored
Normal file
@ -0,0 +1,152 @@
#[cfg(feature = "persistent-cache")]
use std::path::Path;

#[cfg(feature = "persistent-cache")]
use async_trait::async_trait;
#[cfg(feature = "persistent-cache")]
use sled::Db;

#[cfg(feature = "persistent-cache")]
use crate::{cache::EmbeddingCache, error::EmbeddingError, traits::Embedding};

pub struct PersistentCache {
    db: Db,
}

impl PersistentCache {
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EmbeddingError> {
        let db = sled::open(path)
            .map_err(|e| EmbeddingError::CacheError(format!("Failed to open sled db: {}", e)))?;

        Ok(Self { db })
    }

    pub fn in_memory() -> Result<Self, EmbeddingError> {
        let db = sled::Config::new().temporary(true).open().map_err(|e| {
            EmbeddingError::CacheError(format!("Failed to create in-memory sled db: {}", e))
        })?;

        Ok(Self { db })
    }

    fn serialize_embedding(embedding: &Embedding) -> Result<Vec<u8>, EmbeddingError> {
        serde_json::to_vec(embedding)
            .map_err(|e| EmbeddingError::SerializationError(format!("Embedding serialize: {}", e)))
    }

    fn deserialize_embedding(data: &[u8]) -> Result<Embedding, EmbeddingError> {
        serde_json::from_slice(data).map_err(|e| {
            EmbeddingError::SerializationError(format!("Embedding deserialize: {}", e))
        })
    }
}

#[async_trait]
impl EmbeddingCache for PersistentCache {
    async fn get(&self, key: &str) -> Option<Embedding> {
        self.db
            .get(key)
            .ok()
            .flatten()
            .and_then(|bytes| Self::deserialize_embedding(&bytes).ok())
    }

    async fn insert(&self, key: &str, embedding: Embedding) {
        if let Ok(bytes) = Self::serialize_embedding(&embedding) {
            let _ = self.db.insert(key, bytes);
            let _ = self.db.flush();
        }
    }

    async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>> {
        keys.iter()
            .map(|key| {
                self.db
                    .get(key)
                    .ok()
                    .flatten()
                    .and_then(|bytes| Self::deserialize_embedding(&bytes).ok())
            })
            .collect()
    }

    async fn insert_batch(&self, items: Vec<(String, Embedding)>) {
        for (key, embedding) in items {
            if let Ok(bytes) = Self::serialize_embedding(&embedding) {
                let _ = self.db.insert(key, bytes);
            }
        }
        let _ = self.db.flush();
    }

    async fn invalidate(&self, key: &str) {
        let _ = self.db.remove(key);
        let _ = self.db.flush();
    }

    async fn clear(&self) {
        let _ = self.db.clear();
        let _ = self.db.flush();
    }

    fn size(&self) -> usize {
        self.db.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_persistent_cache_in_memory() {
        let cache = PersistentCache::in_memory().expect("Failed to create cache");

        let embedding = vec![1.0, 2.0, 3.0];
        cache.insert("test_key", embedding.clone()).await;

        let retrieved = cache.get("test_key").await;
        assert_eq!(retrieved, Some(embedding));
    }

    #[tokio::test]
    async fn test_persistent_cache_batch() {
        let cache = PersistentCache::in_memory().expect("Failed to create cache");

        let items = vec![
            ("key1".to_string(), vec![1.0, 2.0]),
            ("key2".to_string(), vec![3.0, 4.0]),
        ];

        cache.insert_batch(items).await;

        let keys = vec!["key1".to_string(), "key2".to_string()];
        let results = cache.get_batch(&keys).await;

        assert_eq!(results[0], Some(vec![1.0, 2.0]));
        assert_eq!(results[1], Some(vec![3.0, 4.0]));
    }

    #[tokio::test]
    async fn test_persistent_cache_invalidate() {
        let cache = PersistentCache::in_memory().expect("Failed to create cache");

        cache.insert("key1", vec![1.0]).await;
        assert!(cache.get("key1").await.is_some());

        cache.invalidate("key1").await;
        assert!(cache.get("key1").await.is_none());
    }

    #[tokio::test]
    async fn test_persistent_cache_clear() {
        let cache = PersistentCache::in_memory().expect("Failed to create cache");

        cache.insert("key1", vec![1.0]).await;
        cache.insert("key2", vec![2.0]).await;
        assert_eq!(cache.size(), 2);

        cache.clear().await;
        assert_eq!(cache.size(), 0);
    }
}
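Unlike the in-memory variant, the sled tree survives restarts; in_memory() exists only for tests. A minimal sketch of a disk-backed cache across runs (the ./embed_cache path is illustrative and must be writable):

use stratum_embeddings::cache::EmbeddingCache;
use stratum_embeddings::PersistentCache;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // First run writes; later runs reopen the same sled tree on disk.
    let cache = PersistentCache::new("./embed_cache")?;

    if cache.get("corpus:intro").await.is_none() {
        cache.insert("corpus:intro", vec![0.5, 0.25, 0.125]).await;
    }
    println!("entries on disk: {}", cache.size());
    Ok(())
}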
153
crates/stratum-embeddings/src/config.rs
Normal file
@ -0,0 +1,153 @@
use std::time::Duration;

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    pub provider: ProviderConfig,
    pub cache: CacheConfig,
    #[serde(default)]
    pub batch: BatchConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    FastEmbed {
        model: String,
    },
    OpenAI {
        api_key: String,
        model: String,
        base_url: Option<String>,
    },
    Ollama {
        model: String,
        base_url: Option<String>,
    },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    pub enabled: bool,
    pub max_capacity: u64,
    #[serde(with = "humantime_serde")]
    pub ttl: Duration,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            max_capacity: 10_000,
            ttl: Duration::from_secs(3600),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    pub max_concurrent: usize,
    pub chunk_size: Option<usize>,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_concurrent: 10,
            chunk_size: None,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreSettings {
    pub dimensions: usize,
    pub metric: String,
}

impl EmbeddingConfig {
    pub fn from_env(provider_type: &str) -> Result<Self, crate::error::EmbeddingError> {
        let provider = match provider_type {
            "fastembed" => {
                let model =
                    std::env::var("FASTEMBED_MODEL").unwrap_or_else(|_| "bge-small-en".to_string());
                ProviderConfig::FastEmbed { model }
            }
            "openai" => {
                let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
                    crate::error::EmbeddingError::ConfigError("OPENAI_API_KEY not set".to_string())
                })?;
                let model = std::env::var("OPENAI_MODEL")
                    .unwrap_or_else(|_| "text-embedding-3-small".to_string());
                let base_url = std::env::var("OPENAI_BASE_URL").ok();
                ProviderConfig::OpenAI {
                    api_key,
                    model,
                    base_url,
                }
            }
            "ollama" => {
                let model = std::env::var("OLLAMA_MODEL")
                    .unwrap_or_else(|_| "nomic-embed-text".to_string());
                let base_url = std::env::var("OLLAMA_BASE_URL").ok();
                ProviderConfig::Ollama { model, base_url }
            }
            _ => {
                return Err(crate::error::EmbeddingError::ConfigError(format!(
                    "Unknown provider type: {}",
                    provider_type
                )))
            }
        };

        Ok(Self {
            provider,
            cache: CacheConfig::default(),
            batch: BatchConfig::default(),
        })
    }

    pub fn with_cache(mut self, config: CacheConfig) -> Self {
        self.cache = config;
        self
    }

    pub fn with_batch(mut self, config: BatchConfig) -> Self {
        self.batch = config;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_serialization() {
        let config = EmbeddingConfig {
            provider: ProviderConfig::FastEmbed {
                model: "bge-small-en".to_string(),
            },
            cache: CacheConfig::default(),
            batch: BatchConfig::default(),
        };

        let json = serde_json::to_string(&config).expect("Failed to serialize");
        let deserialized: EmbeddingConfig =
            serde_json::from_str(&json).expect("Failed to deserialize");

        match deserialized.provider {
            ProviderConfig::FastEmbed { model } => assert_eq!(model, "bge-small-en"),
            _ => panic!("Wrong provider type"),
        }
    }

    #[test]
    fn test_cache_config_defaults() {
        let config = CacheConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_capacity, 10_000);
        assert_eq!(config.ttl, Duration::from_secs(3600));
    }
}
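The serde attributes determine the on-disk shape: the type tag selects the provider variant, rename_all lowercases the variant name, and humantime_serde accepts durations like "1h" or "30m". A sketch of parsing such a config from YAML (values illustrative; serde_yaml is already a workspace dependency):

use stratum_embeddings::EmbeddingConfig;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let yaml = r#"
provider:
  type: ollama
  model: nomic-embed-text
  base_url: http://localhost:11434
cache:
  enabled: true
  max_capacity: 10000
  ttl: 1h
batch:
  max_concurrent: 10
  chunk_size: 256
"#;
    // "type: ollama" picks ProviderConfig::Ollama; "1h" parses via humantime.
    let config: EmbeddingConfig = serde_yaml::from_str(yaml)?;
    println!("{:?}", config.provider);
    Ok(())
}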
76
crates/stratum-embeddings/src/error.rs
Normal file
@ -0,0 +1,76 @@
use thiserror::Error;

#[derive(Error, Debug)]
pub enum EmbeddingError {
    #[error("Provider initialization failed: {0}")]
    Initialization(String),

    #[error("Provider not available: {0}")]
    ProviderUnavailable(String),

    #[error("API request failed: {0}")]
    ApiError(String),

    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    #[error("Batch size {size} exceeds maximum {max}")]
    BatchSizeExceeded { size: usize, max: usize },

    #[error("Cache error: {0}")]
    CacheError(String),

    #[error("Store error: {0}")]
    StoreError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("Rate limit exceeded: {0}")]
    RateLimitExceeded(String),

    #[error("Timeout: {0}")]
    Timeout(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("IO error: {0}")]
    IoError(String),

    #[error("HTTP error: {0}")]
    HttpError(String),

    #[error(transparent)]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}

impl From<std::io::Error> for EmbeddingError {
    fn from(err: std::io::Error) -> Self {
        Self::IoError(err.to_string())
    }
}

impl From<serde_json::Error> for EmbeddingError {
    fn from(err: serde_json::Error) -> Self {
        Self::SerializationError(err.to_string())
    }
}

#[cfg(feature = "reqwest")]
impl From<reqwest::Error> for EmbeddingError {
    fn from(err: reqwest::Error) -> Self {
        if err.is_timeout() {
            Self::Timeout(err.to_string())
        } else if err.is_status() {
            Self::HttpError(format!("HTTP {}: {}", err.status().unwrap(), err))
        } else {
            Self::ApiError(err.to_string())
        }
    }
}

pub type Result<T> = std::result::Result<T, EmbeddingError>;
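Because the variants distinguish transient from permanent failures, callers can branch on them for retry decisions. A hedged sketch of such a classification helper (hypothetical, not part of the crate):

use stratum_embeddings::EmbeddingError;

// Hypothetical: transient errors are worth retrying with backoff,
// everything else should surface to the caller immediately.
fn is_transient(err: &EmbeddingError) -> bool {
    matches!(
        err,
        EmbeddingError::Timeout(_)
            | EmbeddingError::RateLimitExceeded(_)
            | EmbeddingError::ProviderUnavailable(_)
    )
}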
34
crates/stratum-embeddings/src/lib.rs
Normal file
@ -0,0 +1,34 @@
pub mod batch;
pub mod cache;
pub mod config;
pub mod error;
pub mod metrics;
pub mod providers;
pub mod service;
pub mod store;
pub mod traits;

#[cfg(feature = "memory-cache")]
pub use cache::MemoryCache;
#[cfg(feature = "persistent-cache")]
pub use cache::PersistentCache;
pub use config::{BatchConfig, CacheConfig, EmbeddingConfig, ProviderConfig};
pub use error::{EmbeddingError, Result};
#[cfg(feature = "cohere-provider")]
pub use providers::cohere::{CohereModel, CohereProvider};
#[cfg(feature = "fastembed-provider")]
pub use providers::fastembed::{FastEmbedModel, FastEmbedProvider};
#[cfg(feature = "huggingface-provider")]
pub use providers::huggingface::{HuggingFaceModel, HuggingFaceProvider};
#[cfg(feature = "ollama-provider")]
pub use providers::ollama::{OllamaModel, OllamaProvider};
#[cfg(feature = "openai-provider")]
pub use providers::openai::{OpenAiModel, OpenAiProvider};
#[cfg(feature = "voyage-provider")]
pub use providers::voyage::{VoyageModel, VoyageProvider};
pub use service::EmbeddingService;
pub use store::*;
pub use traits::{
    cosine_similarity, euclidean_distance, normalize_embedding, Embedding, EmbeddingOptions,
    EmbeddingProvider, EmbeddingResult, ProviderInfo,
};
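The flat re-exports mean downstream crates need a single use line. A small sketch exercising the similarity helpers, assuming the conventional signatures (normalize in place, slice-based cosine returning f32):

use stratum_embeddings::{cosine_similarity, normalize_embedding};

fn main() {
    let mut a = vec![3.0_f32, 4.0];
    let b = vec![4.0_f32, 3.0];

    normalize_embedding(&mut a); // unit length in place
    let sim = cosine_similarity(&a, &b);
    assert!(sim > 0.9); // same quadrant, similar direction
    println!("cosine similarity: {sim}");
}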
195
crates/stratum-embeddings/src/metrics.rs
Normal file
@ -0,0 +1,195 @@
#[cfg(feature = "metrics")]
use std::sync::OnceLock;

#[cfg(feature = "metrics")]
use prometheus::{
    register_histogram_vec, register_int_counter_vec, HistogramOpts, HistogramVec, IntCounterVec,
    Opts,
};

#[cfg(feature = "metrics")]
static EMBEDDING_REQUESTS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static EMBEDDING_ERRORS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static EMBEDDING_DURATION: OnceLock<HistogramVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static CACHE_HITS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static CACHE_MISSES: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static TOKENS_PROCESSED: OnceLock<IntCounterVec> = OnceLock::new();

#[cfg(feature = "metrics")]
pub fn init_metrics() -> Result<(), Box<dyn std::error::Error>> {
    EMBEDDING_REQUESTS.get_or_init(|| {
        register_int_counter_vec!(
            Opts::new(
                "embedding_requests_total",
                "Total number of embedding requests"
            ),
            &["provider", "model"]
        )
        .expect("Failed to register embedding_requests_total")
    });

    EMBEDDING_ERRORS.get_or_init(|| {
        register_int_counter_vec!(
            Opts::new("embedding_errors_total", "Total number of embedding errors"),
            &["provider", "model", "error_type"]
        )
        .expect("Failed to register embedding_errors_total")
    });

    EMBEDDING_DURATION.get_or_init(|| {
        register_histogram_vec!(
            HistogramOpts::new("embedding_duration_seconds", "Embedding request duration")
                .buckets(vec![0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0]),
            &["provider", "model"]
        )
        .expect("Failed to register embedding_duration_seconds")
    });

    CACHE_HITS.get_or_init(|| {
        register_int_counter_vec!(
            Opts::new("embedding_cache_hits_total", "Total cache hits"),
            &["provider", "model"]
        )
        .expect("Failed to register embedding_cache_hits_total")
    });

    CACHE_MISSES.get_or_init(|| {
        register_int_counter_vec!(
            Opts::new("embedding_cache_misses_total", "Total cache misses"),
            &["provider", "model"]
        )
        .expect("Failed to register embedding_cache_misses_total")
    });

    TOKENS_PROCESSED.get_or_init(|| {
        register_int_counter_vec!(
            Opts::new("embedding_tokens_processed_total", "Total tokens processed"),
            &["provider", "model"]
        )
        .expect("Failed to register embedding_tokens_processed_total")
    });

    Ok(())
}

#[cfg(feature = "metrics")]
pub fn record_request(provider: &str, model: &str) {
    if let Some(counter) = EMBEDDING_REQUESTS.get() {
        counter.with_label_values(&[provider, model]).inc();
    }
}

#[cfg(feature = "metrics")]
pub fn record_error(provider: &str, model: &str, error_type: &str) {
    if let Some(counter) = EMBEDDING_ERRORS.get() {
        counter
            .with_label_values(&[provider, model, error_type])
            .inc();
    }
}

#[cfg(feature = "metrics")]
pub fn record_duration(provider: &str, model: &str, duration_secs: f64) {
    if let Some(histogram) = EMBEDDING_DURATION.get() {
        histogram
            .with_label_values(&[provider, model])
            .observe(duration_secs);
    }
}

#[cfg(feature = "metrics")]
pub fn record_cache_hit(provider: &str, model: &str) {
    if let Some(counter) = CACHE_HITS.get() {
        counter.with_label_values(&[provider, model]).inc();
    }
}

#[cfg(feature = "metrics")]
pub fn record_cache_miss(provider: &str, model: &str) {
    if let Some(counter) = CACHE_MISSES.get() {
        counter.with_label_values(&[provider, model]).inc();
    }
}

#[cfg(feature = "metrics")]
pub fn record_tokens(provider: &str, model: &str, tokens: u64) {
    if let Some(counter) = TOKENS_PROCESSED.get() {
        counter.with_label_values(&[provider, model]).inc_by(tokens);
    }
}

#[cfg(feature = "metrics")]
pub struct MetricsGuard {
    provider: String,
    model: String,
    start: std::time::Instant,
}

#[cfg(feature = "metrics")]
impl MetricsGuard {
    pub fn new(provider: &str, model: &str) -> Self {
        record_request(provider, model);
        Self {
            provider: provider.to_string(),
            model: model.to_string(),
            start: std::time::Instant::now(),
        }
    }

    pub fn record_success(self, tokens: Option<u32>) {
        let duration = self.start.elapsed().as_secs_f64();
        record_duration(&self.provider, &self.model, duration);

        if let Some(token_count) = tokens {
            record_tokens(&self.provider, &self.model, token_count as u64);
        }
    }

    pub fn record_error(self, error_type: &str) {
        record_error(&self.provider, &self.model, error_type);
        let duration = self.start.elapsed().as_secs_f64();
        record_duration(&self.provider, &self.model, duration);
    }
}

#[cfg(not(feature = "metrics"))]
pub fn init_metrics() -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}

#[cfg(test)]
#[cfg(feature = "metrics")]
mod tests {
    use super::*;

    #[test]
    fn test_metrics_initialization() {
        let result = init_metrics();
        assert!(result.is_ok());
    }

    #[test]
    fn test_record_request() {
        init_metrics().unwrap();
        record_request("test-provider", "test-model");
    }

    #[test]
    fn test_metrics_guard() {
        init_metrics().unwrap();
        let guard = MetricsGuard::new("test", "model");
        guard.record_success(Some(100));
    }

    #[test]
    fn test_cache_metrics() {
        init_metrics().unwrap();
        record_cache_hit("test", "model");
        record_cache_miss("test", "model");
    }
}
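The guard is consumed by value on either outcome, so each request is timed exactly once. A sketch of wrapping a provider call with it (requires the metrics feature; the error-type label here is illustrative):

use stratum_embeddings::metrics::MetricsGuard;
use stratum_embeddings::{EmbeddingOptions, EmbeddingProvider, Result};

async fn timed_embed(provider: &dyn EmbeddingProvider, text: &str) -> Result<Vec<f32>> {
    let guard = MetricsGuard::new(provider.name(), provider.model());
    match provider
        .embed(text, &EmbeddingOptions::default_with_cache())
        .await
    {
        Ok(embedding) => {
            guard.record_success(None); // token count unknown for local providers
            Ok(embedding)
        }
        Err(err) => {
            guard.record_error("embed_failed"); // illustrative label
            Err(err)
        }
    }
}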
251
crates/stratum-embeddings/src/providers/cohere.rs
Normal file
@ -0,0 +1,251 @@
#[cfg(feature = "cohere-provider")]
use async_trait::async_trait;
#[cfg(feature = "cohere-provider")]
use reqwest::Client;
#[cfg(feature = "cohere-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "cohere-provider")]
use crate::{
    error::EmbeddingError,
    traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};

#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CohereModel {
    #[default]
    EmbedEnglishV3,
    EmbedMultilingualV3,
    EmbedEnglishLightV3,
    EmbedMultilingualLightV3,
    EmbedEnglishV2,
    EmbedMultilingualV2,
}

impl CohereModel {
    pub fn model_name(&self) -> &'static str {
        match self {
            Self::EmbedEnglishV3 => "embed-english-v3.0",
            Self::EmbedMultilingualV3 => "embed-multilingual-v3.0",
            Self::EmbedEnglishLightV3 => "embed-english-light-v3.0",
            Self::EmbedMultilingualLightV3 => "embed-multilingual-light-v3.0",
            Self::EmbedEnglishV2 => "embed-english-v2.0",
            Self::EmbedMultilingualV2 => "embed-multilingual-v2.0",
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::EmbedEnglishV3 | Self::EmbedMultilingualV3 => 1024,
            Self::EmbedEnglishLightV3 | Self::EmbedMultilingualLightV3 => 384,
            Self::EmbedEnglishV2 | Self::EmbedMultilingualV2 => 4096,
        }
    }
}

#[derive(Debug, Serialize)]
struct CohereEmbedRequest {
    model: String,
    texts: Vec<String>,
    input_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    truncate: Option<String>,
}

#[derive(Debug, Deserialize)]
struct CohereEmbedResponse {
    embeddings: Vec<Vec<f32>>,
    meta: CohereMeta,
}

#[derive(Debug, Deserialize)]
struct CohereMeta {
    billed_units: Option<BilledUnits>,
}

#[derive(Debug, Deserialize)]
struct BilledUnits {
    input_tokens: Option<u32>,
}

pub struct CohereProvider {
    client: Client,
    api_key: String,
    model: CohereModel,
    base_url: String,
}

impl CohereProvider {
    pub fn new(api_key: String, model: CohereModel) -> Self {
        Self {
            client: Client::new(),
            api_key,
            model,
            base_url: "https://api.cohere.ai/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    pub fn embed_english_v3(api_key: String) -> Self {
        Self::new(api_key, CohereModel::EmbedEnglishV3)
    }

    pub fn embed_multilingual_v3(api_key: String) -> Self {
        Self::new(api_key, CohereModel::EmbedMultilingualV3)
    }

    async fn embed_batch_internal(
        &self,
        texts: &[&str],
        _options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        let request = CohereEmbedRequest {
            model: self.model.model_name().to_string(),
            texts: texts.iter().map(|s| s.to_string()).collect(),
            input_type: "search_document".to_string(),
            truncate: Some("END".to_string()),
        };

        let response = self
            .client
            .post(format!("{}/embed", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Cohere API request failed: {}", e)))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "Cohere API error {}: {}",
                status, error_text
            )));
        }

        let result: CohereEmbedResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse Cohere response: {}", e))
        })?;

        let total_tokens = result.meta.billed_units.and_then(|b| b.input_tokens);

        Ok(EmbeddingResult {
            embeddings: result.embeddings,
            model: self.model.model_name().to_string(),
            dimensions: self.model.dimensions(),
            total_tokens,
            cached_count: 0,
        })
    }
}

#[async_trait]
impl EmbeddingProvider for CohereProvider {
    fn name(&self) -> &str {
        "cohere"
    }

    fn model(&self) -> &str {
        self.model.model_name()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        match self.model {
            CohereModel::EmbedEnglishV3
            | CohereModel::EmbedMultilingualV3
            | CohereModel::EmbedEnglishLightV3
            | CohereModel::EmbedMultilingualLightV3 => 512,
            CohereModel::EmbedEnglishV2 | CohereModel::EmbedMultilingualV2 => 512,
        }
    }

    fn max_batch_size(&self) -> usize {
        96
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        match self.model {
            CohereModel::EmbedEnglishV3 | CohereModel::EmbedMultilingualV3 => 0.10,
            CohereModel::EmbedEnglishLightV3 | CohereModel::EmbedMultilingualLightV3 => 0.10,
            CohereModel::EmbedEnglishV2 | CohereModel::EmbedMultilingualV2 => 0.10,
        }
    }

    async fn is_available(&self) -> bool {
        !self.api_key.is_empty()
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        let result = self.embed_batch_internal(&[text], options).await?;
        result
            .embeddings
            .into_iter()
            .next()
            .ok_or_else(|| EmbeddingError::ApiError("No embedding returned".to_string()))
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Cannot embed empty text list".to_string(),
            ));
        }

        if texts.len() > self.max_batch_size() {
            return Err(EmbeddingError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size()
            )));
        }

        self.embed_batch_internal(texts, options).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cohere_model_names() {
        assert_eq!(
            CohereModel::EmbedEnglishV3.model_name(),
            "embed-english-v3.0"
        );
        assert_eq!(CohereModel::EmbedMultilingualV3.dimensions(), 1024);
        assert_eq!(CohereModel::EmbedEnglishLightV3.dimensions(), 384);
        assert_eq!(CohereModel::EmbedEnglishV2.dimensions(), 4096);
    }

    #[tokio::test]
    async fn test_cohere_provider_creation() {
        let provider = CohereProvider::new("test-key".to_string(), CohereModel::EmbedEnglishV3);
        assert_eq!(provider.name(), "cohere");
        assert_eq!(provider.model(), "embed-english-v3.0");
        assert_eq!(provider.dimensions(), 1024);
        assert_eq!(provider.max_batch_size(), 96);
    }
}
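A minimal sketch of driving the provider end to end (needs a real COHERE_API_KEY at runtime; the env-var lookup and texts are illustrative):

use stratum_embeddings::{CohereModel, CohereProvider, EmbeddingOptions, EmbeddingProvider};

#[tokio::main]
async fn main() -> stratum_embeddings::Result<()> {
    let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
    let provider = CohereProvider::new(api_key, CohereModel::EmbedEnglishLightV3);

    let result = provider
        .embed_batch(
            &["first doc", "second doc"],
            &EmbeddingOptions::default_with_cache(),
        )
        .await?;
    assert_eq!(result.dimensions, 384);
    println!("billed tokens: {:?}", result.total_tokens);
    Ok(())
}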
247
crates/stratum-embeddings/src/providers/fastembed.rs
Normal file
@ -0,0 +1,247 @@
#[cfg(feature = "fastembed-provider")]
use std::sync::{Arc, Mutex};

#[cfg(feature = "fastembed-provider")]
use async_trait::async_trait;
#[cfg(feature = "fastembed-provider")]
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

#[cfg(feature = "fastembed-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FastEmbedModel {
    #[default]
    BgeSmallEn,
    BgeBaseEn,
    BgeLargeEn,
    AllMiniLmL6V2,
    MultilingualE5Small,
    MultilingualE5Base,
}

impl FastEmbedModel {
    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmallEn | Self::AllMiniLmL6V2 | Self::MultilingualE5Small => 384,
            Self::BgeBaseEn | Self::MultilingualE5Base => 768,
            Self::BgeLargeEn => 1024,
        }
    }

    pub fn model_name(&self) -> &'static str {
        match self {
            Self::BgeSmallEn => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEn => "BAAI/bge-base-en-v1.5",
            Self::BgeLargeEn => "BAAI/bge-large-en-v1.5",
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::MultilingualE5Small => "intfloat/multilingual-e5-small",
            Self::MultilingualE5Base => "intfloat/multilingual-e5-base",
        }
    }

    fn to_fastembed_model(self) -> EmbeddingModel {
        match self {
            Self::BgeSmallEn => EmbeddingModel::BGESmallENV15,
            Self::BgeBaseEn => EmbeddingModel::BGEBaseENV15,
            Self::BgeLargeEn => EmbeddingModel::BGELargeENV15,
            Self::AllMiniLmL6V2 => EmbeddingModel::AllMiniLML6V2,
            Self::MultilingualE5Small => EmbeddingModel::MultilingualE5Small,
            Self::MultilingualE5Base => EmbeddingModel::MultilingualE5Base,
        }
    }
}

pub struct FastEmbedProvider {
    model: Arc<Mutex<TextEmbedding>>,
    model_type: FastEmbedModel,
}

impl FastEmbedProvider {
    pub fn new(model_type: FastEmbedModel) -> Result<Self, EmbeddingError> {
        let options =
            InitOptions::new(model_type.to_fastembed_model()).with_show_download_progress(true);

        let model = TextEmbedding::try_new(options)
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            model: Arc::new(Mutex::new(model)),
            model_type,
        })
    }

    pub fn default_model() -> Result<Self, EmbeddingError> {
        Self::new(FastEmbedModel::default())
    }

    pub fn small() -> Result<Self, EmbeddingError> {
        Self::new(FastEmbedModel::BgeSmallEn)
    }

    pub fn base() -> Result<Self, EmbeddingError> {
        Self::new(FastEmbedModel::BgeBaseEn)
    }

    pub fn large() -> Result<Self, EmbeddingError> {
        Self::new(FastEmbedModel::BgeLargeEn)
    }

    pub fn multilingual() -> Result<Self, EmbeddingError> {
        Self::new(FastEmbedModel::MultilingualE5Base)
    }
}

#[async_trait]
impl EmbeddingProvider for FastEmbedProvider {
    fn name(&self) -> &str {
        "fastembed"
    }

    fn model(&self) -> &str {
        self.model_type.model_name()
    }

    fn dimensions(&self) -> usize {
        self.model_type.dimensions()
    }

    fn is_local(&self) -> bool {
        true
    }

    fn max_tokens(&self) -> usize {
        512
    }

    fn max_batch_size(&self) -> usize {
        256
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        0.0
    }

    async fn is_available(&self) -> bool {
        true
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        let text = text.to_string();
        let model = Arc::clone(&self.model);

        let embeddings = tokio::task::spawn_blocking(move || {
            let mut model_guard = model.lock().expect("Failed to acquire model lock");
            model_guard.embed(vec![text], None)
        })
        .await
        .map_err(|e| EmbeddingError::ApiError(format!("Task join error: {}", e)))?
        .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;

        let mut embedding = embeddings.into_iter().next().ok_or_else(|| {
            EmbeddingError::ApiError("FastEmbed returned no embeddings".to_string())
        })?;

        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.len() > self.max_batch_size() {
            return Err(EmbeddingError::BatchSizeExceeded {
                size: texts.len(),
                max: self.max_batch_size(),
            });
        }

        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
        let model = Arc::clone(&self.model);

        let mut embeddings = tokio::task::spawn_blocking(move || {
            let mut model_guard = model.lock().expect("Failed to acquire model lock");
            model_guard.embed(texts_owned, None)
        })
        .await
        .map_err(|e| EmbeddingError::ApiError(format!("Task join error: {}", e)))?
        .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;

        if options.normalize {
            embeddings.iter_mut().for_each(normalize_embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            total_tokens: None,
            cached_count: 0,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_fastembed_provider() {
        let provider = FastEmbedProvider::small().expect("Failed to initialize FastEmbed");

        assert_eq!(provider.name(), "fastembed");
        assert_eq!(provider.dimensions(), 384);
        assert!(provider.is_local());
        assert_eq!(provider.cost_per_1m_tokens(), 0.0);

        let options = EmbeddingOptions::default_with_cache();
        let embedding = provider
            .embed("Hello world", &options)
            .await
            .expect("Failed to embed");

        assert_eq!(embedding.len(), 384);

        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (magnitude - 1.0).abs() < 0.01,
            "Embedding should be normalized"
        );
    }

    #[tokio::test]
    async fn test_fastembed_batch() {
        let provider = FastEmbedProvider::small().expect("Failed to initialize FastEmbed");

        let texts = vec!["Hello world", "Goodbye world", "Machine learning"];
        let options = EmbeddingOptions::default_with_cache();

        let result = provider
            .embed_batch(&texts, &options)
            .await
            .expect("Failed to embed batch");

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.dimensions, 384);

        for embedding in &result.embeddings {
            assert_eq!(embedding.len(), 384);
            let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((magnitude - 1.0).abs() < 0.01);
        }
    }
}
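ONNX inference is CPU-bound, so both embed paths hand the work to Tokio's blocking pool instead of stalling the async executor, while Arc<Mutex<TextEmbedding>> keeps the provider Send + Sync and serializes model access. The same pattern in isolation (the Model type here is an illustrative stand-in, not crate code):

use std::sync::{Arc, Mutex};

// Stand-in for a CPU-bound model; the real provider wraps
// fastembed::TextEmbedding the same way.
struct Model;
impl Model {
    fn infer(&mut self, input: &str) -> usize {
        input.len() // placeholder for real inference
    }
}

#[tokio::main]
async fn main() {
    let model = Arc::new(Mutex::new(Model));
    let handle = Arc::clone(&model);

    // Heavy work runs on the blocking pool; the async executor stays free.
    let out = tokio::task::spawn_blocking(move || {
        let mut guard = handle.lock().expect("model lock poisoned");
        guard.infer("hello world")
    })
    .await
    .expect("blocking task panicked");

    assert_eq!(out, 11);
}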
344
crates/stratum-embeddings/src/providers/huggingface.rs
Normal file
@ -0,0 +1,344 @@
#[cfg(feature = "huggingface-provider")]
use std::time::Duration;

#[cfg(feature = "huggingface-provider")]
use async_trait::async_trait;
#[cfg(feature = "huggingface-provider")]
use reqwest::Client;
#[cfg(feature = "huggingface-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "huggingface-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

#[derive(Debug, Clone, PartialEq)]
pub enum HuggingFaceModel {
    /// BAAI/bge-small-en-v1.5 - 384 dimensions, efficient general-purpose
    BgeSmall,
    /// BAAI/bge-base-en-v1.5 - 768 dimensions, balanced performance
    BgeBase,
    /// BAAI/bge-large-en-v1.5 - 1024 dimensions, high quality
    BgeLarge,
    /// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
    AllMiniLm,
    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
    AllMpnet,
    /// Custom model with model ID and dimensions
    Custom(String, usize),
}

impl Default for HuggingFaceModel {
    fn default() -> Self {
        Self::BgeSmall
    }
}

impl HuggingFaceModel {
    pub fn model_id(&self) -> &str {
        match self {
            Self::BgeSmall => "BAAI/bge-small-en-v1.5",
            Self::BgeBase => "BAAI/bge-base-en-v1.5",
            Self::BgeLarge => "BAAI/bge-large-en-v1.5",
            Self::AllMiniLm => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMpnet => "sentence-transformers/all-mpnet-base-v2",
            Self::Custom(id, _) => id.as_str(),
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmall => 384,
            Self::BgeBase => 768,
            Self::BgeLarge => 1024,
            Self::AllMiniLm => 384,
            Self::AllMpnet => 768,
            Self::Custom(_, dims) => *dims,
        }
    }

    pub fn from_model_id(id: &str, dimensions: Option<usize>) -> Self {
        match id {
            "BAAI/bge-small-en-v1.5" => Self::BgeSmall,
            "BAAI/bge-base-en-v1.5" => Self::BgeBase,
            "BAAI/bge-large-en-v1.5" => Self::BgeLarge,
            "sentence-transformers/all-MiniLM-L6-v2" => Self::AllMiniLm,
            "sentence-transformers/all-mpnet-base-v2" => Self::AllMpnet,
            _ => Self::Custom(
                id.to_string(),
                dimensions.unwrap_or(384), // Default to 384 if unknown
            ),
        }
    }
}

#[cfg(feature = "huggingface-provider")]
pub struct HuggingFaceProvider {
    client: Client,
    api_key: String,
    model: HuggingFaceModel,
}

#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Serialize)]
struct HFRequest {
    inputs: String,
}

#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFResponse {
    /// Single text embedding response
    Single(Vec<f32>),
    /// Batch embedding response
    Multiple(Vec<Vec<f32>>),
}

#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
    const BASE_URL: &'static str =
        "https://api-inference.huggingface.co/pipeline/feature-extraction";

    pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(EmbeddingError::ConfigError(
                "HuggingFace API key is empty".to_string(),
            ));
        }

        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            client,
            api_key,
            model,
        })
    }

    pub fn from_env(model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
        let api_key = std::env::var("HUGGINGFACE_API_KEY")
            .or_else(|_| std::env::var("HF_TOKEN"))
            .map_err(|_| {
                EmbeddingError::ConfigError(
                    "HUGGINGFACE_API_KEY or HF_TOKEN environment variable not set".to_string(),
                )
            })?;
        Self::new(api_key, model)
    }

    pub fn bge_small() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeSmall)
    }

    pub fn bge_base() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::BgeBase)
    }

    pub fn all_minilm() -> Result<Self, EmbeddingError> {
        Self::from_env(HuggingFaceModel::AllMiniLm)
    }
}

#[cfg(feature = "huggingface-provider")]
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
    fn name(&self) -> &str {
        "huggingface"
    }

    fn model(&self) -> &str {
        self.model.model_id()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        // HuggingFace doesn't specify a hard limit, but most models handle ~512 tokens
        512
    }

    fn max_batch_size(&self) -> usize {
        // HuggingFace Inference API doesn't support batch requests
        // Each request is individual
        1
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        // HuggingFace Inference API is free for public models
        // For dedicated endpoints, costs vary
        0.0
    }

    async fn is_available(&self) -> bool {
        // HuggingFace Inference API is always available (with rate limits)
        true
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        let url = format!("{}/{}", Self::BASE_URL, self.model.model_id());
        let request = HFRequest {
            inputs: text.to_string(),
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                EmbeddingError::ApiError(format!("HuggingFace API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "HuggingFace API error {}: {}",
                status, error_text
            )));
        }

        let hf_response: HFResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse HuggingFace response: {}", e))
        })?;

        let mut embedding = match hf_response {
            HFResponse::Single(emb) => emb,
            HFResponse::Multiple(embs) => {
                if embs.is_empty() {
                    return Err(EmbeddingError::ApiError(
                        "Empty embeddings response from HuggingFace".to_string(),
                    ));
                }
                embs[0].clone()
            }
        };

        // Validate dimensions
        if embedding.len() != self.dimensions() {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions(),
                actual: embedding.len(),
            });
        }

        // Normalize if requested
        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        // HuggingFace Inference API doesn't support true batch requests
        // We need to send individual requests
        let mut embeddings = Vec::with_capacity(texts.len());

        for text in texts {
            let embedding = self.embed(text, options).await?;
            embeddings.push(embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            total_tokens: None,
            cached_count: 0,
        })
    }
}

#[cfg(all(test, feature = "huggingface-provider"))]
mod tests {
    use super::*;

    #[test]
    fn test_model_id_mapping() {
        assert_eq!(
            HuggingFaceModel::BgeSmall.model_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeSmall.dimensions(), 384);

        assert_eq!(
            HuggingFaceModel::BgeBase.model_id(),
            "BAAI/bge-base-en-v1.5"
        );
        assert_eq!(HuggingFaceModel::BgeBase.dimensions(), 768);
    }

    #[test]
    fn test_custom_model() {
        let custom = HuggingFaceModel::Custom("my-model".to_string(), 512);
        assert_eq!(custom.model_id(), "my-model");
        assert_eq!(custom.dimensions(), 512);
    }

    #[test]
    fn test_from_model_id() {
        let model = HuggingFaceModel::from_model_id("BAAI/bge-small-en-v1.5", None);
        assert_eq!(model, HuggingFaceModel::BgeSmall);

        let custom = HuggingFaceModel::from_model_id("unknown-model", Some(256));
        assert!(matches!(custom, HuggingFaceModel::Custom(_, 256)));
    }

    #[test]
    fn test_provider_creation() {
        let provider = HuggingFaceProvider::new("test-key", HuggingFaceModel::BgeSmall);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.name(), "huggingface");
        assert_eq!(provider.model(), "BAAI/bge-small-en-v1.5");
        assert_eq!(provider.dimensions(), 384);
    }

    #[test]
    fn test_empty_api_key() {
        let provider = HuggingFaceProvider::new("", HuggingFaceModel::BgeSmall);
        assert!(provider.is_err());
        assert!(matches!(
            provider.unwrap_err(),
            EmbeddingError::ConfigError(_)
        ));
    }
}
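Since max_batch_size() is 1, embed_batch above degrades to a sequential loop of single-input requests. A hypothetical concurrent variant using the futures crate, bounded so rate limits are respected (not part of the crate):

use futures::stream::{self, StreamExt};
use stratum_embeddings::{EmbeddingOptions, EmbeddingProvider, Result};

// Hypothetical helper: fan out individual requests, at most 4 in flight,
// preserving input order in the output.
async fn embed_concurrent(
    provider: &dyn EmbeddingProvider,
    texts: &[&str],
    options: &EmbeddingOptions,
) -> Result<Vec<Vec<f32>>> {
    stream::iter(texts)
        .map(|text| provider.embed(text, options))
        .buffered(4)
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .collect()
}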
25
crates/stratum-embeddings/src/providers/mod.rs
Normal file
@ -0,0 +1,25 @@
#[cfg(feature = "cohere-provider")]
pub mod cohere;
#[cfg(feature = "fastembed-provider")]
pub mod fastembed;
#[cfg(feature = "huggingface-provider")]
pub mod huggingface;
#[cfg(feature = "ollama-provider")]
pub mod ollama;
#[cfg(feature = "openai-provider")]
pub mod openai;
#[cfg(feature = "voyage-provider")]
pub mod voyage;

#[cfg(feature = "cohere-provider")]
pub use cohere::{CohereModel, CohereProvider};
#[cfg(feature = "fastembed-provider")]
pub use fastembed::{FastEmbedModel, FastEmbedProvider};
#[cfg(feature = "huggingface-provider")]
pub use huggingface::{HuggingFaceModel, HuggingFaceProvider};
#[cfg(feature = "ollama-provider")]
pub use ollama::{OllamaModel, OllamaProvider};
#[cfg(feature = "openai-provider")]
pub use openai::{OpenAiModel, OpenAiProvider};
#[cfg(feature = "voyage-provider")]
pub use voyage::{VoyageModel, VoyageProvider};
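Every provider implements the same EmbeddingProvider trait, so callers can select one at runtime behind a trait object. A sketch assuming the fastembed-provider and ollama-provider features are enabled (the fallback policy is hypothetical):

use stratum_embeddings::{
    EmbeddingProvider, FastEmbedProvider, OllamaModel, OllamaProvider, Result,
};

// Hypothetical factory: prefer a running Ollama daemon, otherwise fall
// back to the bundled local FastEmbed model.
async fn pick_provider() -> Result<Box<dyn EmbeddingProvider>> {
    let ollama = OllamaProvider::new(OllamaModel::NomicEmbed)?;
    if ollama.is_available().await {
        return Ok(Box::new(ollama));
    }
    Ok(Box::new(FastEmbedProvider::small()?))
}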
272
crates/stratum-embeddings/src/providers/ollama.rs
Normal file
@ -0,0 +1,272 @@
#[cfg(feature = "ollama-provider")]
use std::time::Duration;

#[cfg(feature = "ollama-provider")]
use async_trait::async_trait;
#[cfg(feature = "ollama-provider")]
use reqwest::Client;
#[cfg(feature = "ollama-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "ollama-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

#[derive(Debug, Clone, PartialEq, Default)]
pub enum OllamaModel {
    #[default]
    NomicEmbed,
    MxbaiEmbed,
    AllMiniLm,
    Custom(String, usize),
}

impl OllamaModel {
    pub fn model_name(&self) -> &str {
        match self {
            Self::NomicEmbed => "nomic-embed-text",
            Self::MxbaiEmbed => "mxbai-embed-large",
            Self::AllMiniLm => "all-minilm",
            Self::Custom(name, _) => name,
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::NomicEmbed => 768,
            Self::MxbaiEmbed => 1024,
            Self::AllMiniLm => 384,
            Self::Custom(_, dims) => *dims,
        }
    }
}

#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    input: OllamaInput,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<serde_json::Value>,
}

#[derive(Serialize)]
#[serde(untagged)]
enum OllamaInput {
    Single(String),
    Batch(Vec<String>),
}

#[derive(Deserialize)]
struct OllamaResponse {
    #[serde(default)]
    embeddings: Vec<Vec<f32>>,
    #[serde(default)]
    embedding: Option<Vec<f32>>,
}

pub struct OllamaProvider {
    client: Client,
    model: OllamaModel,
    base_url: String,
}

impl OllamaProvider {
    pub fn new(model: OllamaModel) -> Result<Self, EmbeddingError> {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            client,
            model,
            base_url: "http://localhost:11434".to_string(),
        })
    }

    pub fn default_model() -> Result<Self, EmbeddingError> {
        Self::new(OllamaModel::default())
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    pub fn custom_model(
        name: impl Into<String>,
        dimensions: usize,
    ) -> Result<Self, EmbeddingError> {
        Self::new(OllamaModel::Custom(name.into(), dimensions))
    }

    async fn call_api(&self, input: OllamaInput) -> Result<OllamaResponse, EmbeddingError> {
        let request = OllamaRequest {
            model: self.model.model_name().to_string(),
            input,
            options: None,
        };

        let response = self
            .client
            .post(format!("{}/api/embed", self.base_url))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                if e.is_connect() {
                    EmbeddingError::ProviderUnavailable(format!(
                        "Cannot connect to Ollama at {}",
                        self.base_url
                    ))
                } else {
                    e.into()
                }
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(EmbeddingError::ApiError(format!(
                "Ollama API error {}: {}",
                status, error_text
            )));
        }

        let api_response: OllamaResponse = response.json().await?;
        Ok(api_response)
    }
}

#[async_trait]
impl EmbeddingProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        self.model.model_name()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        true
    }

    fn max_tokens(&self) -> usize {
        2048
    }

    fn max_batch_size(&self) -> usize {
        128
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        0.0
    }

    async fn is_available(&self) -> bool {
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        let response = self.call_api(OllamaInput::Single(text.to_string())).await?;

        let mut embedding = if let Some(emb) = response.embedding {
            emb
        } else {
            response.embeddings.into_iter().next().ok_or_else(|| {
                EmbeddingError::ApiError("Ollama returned no embeddings".to_string())
            })?
        };

        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        if texts.len() > self.max_batch_size() {
            return Err(EmbeddingError::BatchSizeExceeded {
                size: texts.len(),
                max: self.max_batch_size(),
            });
        }

        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
        let response = self.call_api(OllamaInput::Batch(texts_owned)).await?;

        let mut embeddings = response.embeddings;

        if embeddings.is_empty() {
            return Err(EmbeddingError::ApiError(
                "Ollama returned no embeddings".to_string(),
            ));
        }

        if options.normalize {
            embeddings.iter_mut().for_each(normalize_embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            total_tokens: None,
            cached_count: 0,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ollama_model_metadata() {
        assert_eq!(OllamaModel::NomicEmbed.dimensions(), 768);
        assert_eq!(OllamaModel::MxbaiEmbed.dimensions(), 1024);
        assert_eq!(OllamaModel::AllMiniLm.dimensions(), 384);

        let custom = OllamaModel::Custom("my-model".to_string(), 512);
        assert_eq!(custom.model_name(), "my-model");
        assert_eq!(custom.dimensions(), 512);
    }
}
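OllamaResponse accepts both the plural embeddings field returned by /api/embed and the legacy singular embedding, which is why embed() checks the Option first. A self-contained check of that dual shape (mirrors the provider's private struct; not crate code):

use serde::Deserialize;

// Newer Ollama builds return "embeddings": [[...]]; older ones returned
// a single "embedding": [...]. Both deserialize into one struct.
#[derive(Deserialize)]
struct Response {
    #[serde(default)]
    embeddings: Vec<Vec<f32>>,
    #[serde(default)]
    embedding: Option<Vec<f32>>,
}

fn main() {
    let new_style: Response =
        serde_json::from_str(r#"{"embeddings": [[0.1, 0.2]]}"#).unwrap();
    assert_eq!(new_style.embeddings.len(), 1);

    let old_style: Response =
        serde_json::from_str(r#"{"embedding": [0.1, 0.2]}"#).unwrap();
    assert!(old_style.embedding.is_some());
}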
286
crates/stratum-embeddings/src/providers/openai.rs
Normal file
286
crates/stratum-embeddings/src/providers/openai.rs
Normal file
@ -0,0 +1,286 @@
|
||||
#[cfg(feature = "openai-provider")]
use std::time::Duration;

#[cfg(feature = "openai-provider")]
use async_trait::async_trait;
#[cfg(feature = "openai-provider")]
use reqwest::Client;
#[cfg(feature = "openai-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "openai-provider")]
use crate::{
    error::EmbeddingError,
    traits::{
        normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
    },
};

#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum OpenAiModel {
    #[default]
    TextEmbedding3Small,
    TextEmbedding3Large,
    TextEmbeddingAda002,
}

impl OpenAiModel {
    pub fn model_name(&self) -> &'static str {
        match self {
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::TextEmbedding3Small => 1536,
            Self::TextEmbedding3Large => 3072,
            Self::TextEmbeddingAda002 => 1536,
        }
    }

    pub fn cost_per_1m(&self) -> f64 {
        match self {
            Self::TextEmbedding3Small => 0.02,
            Self::TextEmbedding3Large => 0.13,
            Self::TextEmbeddingAda002 => 0.10,
        }
    }
}

#[derive(Serialize)]
struct OpenAiRequest {
    input: OpenAiInput,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
}

#[derive(Serialize)]
#[serde(untagged)]
enum OpenAiInput {
    Single(String),
    Batch(Vec<String>),
}

#[derive(Deserialize)]
struct OpenAiResponse {
    data: Vec<OpenAiEmbedding>,
    usage: OpenAiUsage,
}

#[derive(Deserialize)]
struct OpenAiEmbedding {
    embedding: Vec<f32>,
    index: usize,
}

#[derive(Deserialize)]
struct OpenAiUsage {
    total_tokens: u32,
}

pub struct OpenAiProvider {
    client: Client,
    api_key: String,
    model: OpenAiModel,
    base_url: String,
}

impl OpenAiProvider {
    pub fn new(api_key: impl Into<String>, model: OpenAiModel) -> Result<Self, EmbeddingError> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(EmbeddingError::ConfigError(
                "OpenAI API key is empty".to_string(),
            ));
        }

        let client = Client::builder()
            .timeout(Duration::from_secs(60))
            .build()
            .map_err(|e| EmbeddingError::Initialization(e.to_string()))?;

        Ok(Self {
            client,
            api_key,
            model,
            base_url: "https://api.openai.com/v1".to_string(),
        })
    }

    pub fn from_env(model: OpenAiModel) -> Result<Self, EmbeddingError> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| EmbeddingError::ConfigError("OPENAI_API_KEY not set".to_string()))?;
        Self::new(api_key, model)
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    async fn call_api(&self, input: OpenAiInput) -> Result<OpenAiResponse, EmbeddingError> {
        let request = OpenAiRequest {
            input,
            model: self.model.model_name().to_string(),
            encoding_format: None,
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(EmbeddingError::ApiError(format!(
                "OpenAI API error {}: {}",
                status, error_text
            )));
        }

        let api_response: OpenAiResponse = response.json().await?;
        Ok(api_response)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAiProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        self.model.model_name()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        8191
    }

    fn max_batch_size(&self) -> usize {
        2048
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        self.model.cost_per_1m()
    }

    async fn is_available(&self) -> bool {
        self.client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        let response = self.call_api(OpenAiInput::Single(text.to_string())).await?;

        let mut embedding = response
            .data
            .into_iter()
            .next()
            .ok_or_else(|| EmbeddingError::ApiError("OpenAI returned no embeddings".to_string()))?
            .embedding;

        if options.normalize {
            normalize_embedding(&mut embedding);
        }

        Ok(embedding)
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        if texts.len() > self.max_batch_size() {
            return Err(EmbeddingError::BatchSizeExceeded {
                size: texts.len(),
                max: self.max_batch_size(),
            });
        }

        let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
        let response = self.call_api(OpenAiInput::Batch(texts_owned)).await?;

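        // The API may return items out of order; collect them with their
        // `index` field so the original input order can be restored below.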
        let mut embeddings_with_index: Vec<_> = response
            .data
            .into_iter()
            .map(|e| (e.index, e.embedding))
            .collect();

        embeddings_with_index.sort_by_key(|(idx, _)| *idx);

        let mut embeddings: Vec<Embedding> = embeddings_with_index
            .into_iter()
            .map(|(_, emb)| emb)
            .collect();

        if options.normalize {
            embeddings.iter_mut().for_each(normalize_embedding);
        }

        Ok(EmbeddingResult {
            embeddings,
            model: self.model().to_string(),
            dimensions: self.dimensions(),
            total_tokens: Some(response.usage.total_tokens),
            cached_count: 0,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_model_metadata() {
        assert_eq!(OpenAiModel::TextEmbedding3Small.dimensions(), 1536);
        assert_eq!(OpenAiModel::TextEmbedding3Large.dimensions(), 3072);
        assert_eq!(OpenAiModel::TextEmbeddingAda002.dimensions(), 1536);

        assert_eq!(OpenAiModel::TextEmbedding3Small.cost_per_1m(), 0.02);
        assert_eq!(OpenAiModel::TextEmbedding3Large.cost_per_1m(), 0.13);
    }
}
259
crates/stratum-embeddings/src/providers/voyage.rs
Normal file
@ -0,0 +1,259 @@
#[cfg(feature = "voyage-provider")]
use async_trait::async_trait;
#[cfg(feature = "voyage-provider")]
use reqwest::Client;
#[cfg(feature = "voyage-provider")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "voyage-provider")]
use crate::{
    error::EmbeddingError,
    traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};

#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum VoyageModel {
    #[default]
    Voyage2,
    VoyageLarge2,
    VoyageCode2,
    VoyageLite02Instruct,
}

impl VoyageModel {
    pub fn model_name(&self) -> &'static str {
        match self {
            Self::Voyage2 => "voyage-2",
            Self::VoyageLarge2 => "voyage-large-2",
            Self::VoyageCode2 => "voyage-code-2",
            Self::VoyageLite02Instruct => "voyage-lite-02-instruct",
        }
    }

    pub fn dimensions(&self) -> usize {
        match self {
            Self::Voyage2 => 1024,
            Self::VoyageLarge2 => 1536,
            Self::VoyageCode2 => 1536,
            Self::VoyageLite02Instruct => 1024,
        }
    }

    pub fn max_tokens(&self) -> usize {
        match self {
            Self::Voyage2 | Self::VoyageLarge2 | Self::VoyageCode2 => 16000,
            Self::VoyageLite02Instruct => 4000,
        }
    }
}

#[derive(Debug, Serialize)]
struct VoyageEmbedRequest {
    input: Vec<String>,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    truncation: Option<bool>,
}

#[derive(Debug, Deserialize)]
struct VoyageEmbedResponse {
    data: Vec<VoyageEmbeddingData>,
    usage: VoyageUsage,
}

#[derive(Debug, Deserialize)]
struct VoyageEmbeddingData {
    embedding: Vec<f32>,
}

#[derive(Debug, Deserialize)]
struct VoyageUsage {
    total_tokens: u32,
}

pub struct VoyageProvider {
    client: Client,
    api_key: String,
    model: VoyageModel,
    base_url: String,
}

impl VoyageProvider {
    pub fn new(api_key: String, model: VoyageModel) -> Self {
        Self {
            client: Client::new(),
            api_key,
            model,
            base_url: "https://api.voyageai.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    pub fn voyage_2(api_key: String) -> Self {
        Self::new(api_key, VoyageModel::Voyage2)
    }

    pub fn voyage_large_2(api_key: String) -> Self {
        Self::new(api_key, VoyageModel::VoyageLarge2)
    }

    pub fn voyage_code_2(api_key: String) -> Self {
        Self::new(api_key, VoyageModel::VoyageCode2)
    }

    async fn embed_batch_internal(
        &self,
        texts: &[&str],
        _options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        let request = VoyageEmbedRequest {
            input: texts.iter().map(|s| s.to_string()).collect(),
            model: self.model.model_name().to_string(),
            input_type: Some("document".to_string()),
            truncation: Some(true),
        };

        let response = self
            .client
            .post(format!("{}/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Voyage API request failed: {}", e)))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(EmbeddingError::ApiError(format!(
                "Voyage API error {}: {}",
                status, error_text
            )));
        }

        let result: VoyageEmbedResponse = response.json().await.map_err(|e| {
            EmbeddingError::ApiError(format!("Failed to parse Voyage response: {}", e))
        })?;

        let embeddings: Vec<Embedding> = result.data.into_iter().map(|d| d.embedding).collect();

        Ok(EmbeddingResult {
            embeddings,
            model: self.model.model_name().to_string(),
            dimensions: self.model.dimensions(),
            total_tokens: Some(result.usage.total_tokens),
            cached_count: 0,
        })
    }
}

#[async_trait]
impl EmbeddingProvider for VoyageProvider {
    fn name(&self) -> &str {
        "voyage"
    }

    fn model(&self) -> &str {
        self.model.model_name()
    }

    fn dimensions(&self) -> usize {
        self.model.dimensions()
    }

    fn is_local(&self) -> bool {
        false
    }

    fn max_tokens(&self) -> usize {
        self.model.max_tokens()
    }

    fn max_batch_size(&self) -> usize {
        128
    }

    fn cost_per_1m_tokens(&self) -> f64 {
        match self.model {
            VoyageModel::Voyage2 => 0.10,
            VoyageModel::VoyageLarge2 => 0.12,
            VoyageModel::VoyageCode2 => 0.12,
            VoyageModel::VoyageLite02Instruct => 0.06,
        }
    }

    async fn is_available(&self) -> bool {
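        // Cheap readiness check: only verifies a key is configured;
        // no network round-trip is made here.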
        !self.api_key.is_empty()
    }

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        let result = self.embed_batch_internal(&[text], options).await?;
        result
            .embeddings
            .into_iter()
            .next()
            .ok_or_else(|| EmbeddingError::ApiError("No embedding returned".to_string()))
    }

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Cannot embed empty text list".to_string(),
            ));
        }

        if texts.len() > self.max_batch_size() {
            return Err(EmbeddingError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size()
            )));
        }

        self.embed_batch_internal(texts, options).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_voyage_model_names() {
        assert_eq!(VoyageModel::Voyage2.model_name(), "voyage-2");
        assert_eq!(VoyageModel::Voyage2.dimensions(), 1024);
        assert_eq!(VoyageModel::VoyageLarge2.dimensions(), 1536);
        assert_eq!(VoyageModel::VoyageCode2.dimensions(), 1536);
        assert_eq!(VoyageModel::VoyageLite02Instruct.dimensions(), 1024);
    }

    #[test]
    fn test_voyage_max_tokens() {
        assert_eq!(VoyageModel::Voyage2.max_tokens(), 16000);
        assert_eq!(VoyageModel::VoyageLite02Instruct.max_tokens(), 4000);
    }

    #[tokio::test]
    async fn test_voyage_provider_creation() {
        let provider = VoyageProvider::new("test-key".to_string(), VoyageModel::Voyage2);
        assert_eq!(provider.name(), "voyage");
        assert_eq!(provider.model(), "voyage-2");
        assert_eq!(provider.dimensions(), 1024);
        assert_eq!(provider.max_batch_size(), 128);
    }
}
283
crates/stratum-embeddings/src/service.rs
Normal file
@ -0,0 +1,283 @@
use std::sync::Arc;

use tracing::{debug, info, warn};

use crate::{
    batch::BatchProcessor,
    cache::EmbeddingCache,
    error::EmbeddingError,
    traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult, ProviderInfo},
};

pub struct EmbeddingService<P: EmbeddingProvider, C: EmbeddingCache> {
    provider: Arc<P>,
    cache: Option<Arc<C>>,
    fallback_providers: Vec<Arc<dyn EmbeddingProvider>>,
    batch_processor: BatchProcessor<P, C>,
}

impl<P: EmbeddingProvider + 'static, C: EmbeddingCache + 'static> EmbeddingService<P, C> {
    pub fn new(provider: P) -> Self {
        let provider = Arc::new(provider);
        let batch_processor = BatchProcessor::new(Arc::clone(&provider), None);

        Self {
            provider,
            cache: None,
            fallback_providers: Vec::new(),
            batch_processor,
        }
    }

    pub fn with_cache(mut self, cache: C) -> Self {
        let cache = Arc::new(cache);
        self.cache = Some(Arc::clone(&cache));
        self.batch_processor = BatchProcessor::new(Arc::clone(&self.provider), Some(cache));
        self
    }

    pub fn with_fallback(mut self, fallback: Arc<dyn EmbeddingProvider>) -> Self {
        self.fallback_providers.push(fallback);
        self
    }

    pub fn with_batch_concurrency(mut self, max_concurrent: usize) -> Self {
        self.batch_processor = self.batch_processor.with_concurrency(max_concurrent);
        self
    }

    pub fn provider_info(&self) -> ProviderInfo {
        ProviderInfo::from(self.provider.as_ref())
    }

    pub async fn is_ready(&self) -> bool {
        self.provider.is_available().await
    }

    pub async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        if text.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Text cannot be empty".to_string(),
            ));
        }

        if options.use_cache {
            if let Some(cache) = &self.cache {
                let key =
                    crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
                if let Some(cached) = cache.get(&key).await {
                    debug!("Cache hit for text (len={})", text.len());
                    return Ok(cached);
                }
            }
        }

        let result = self.embed_with_fallback(text, options).await?;

        if options.use_cache {
            if let Some(cache) = &self.cache {
                let key =
                    crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
                cache.insert(&key, result.clone()).await;
            }
        }

        Ok(result)
    }

    pub async fn embed_batch(
        &self,
        texts: Vec<String>,
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, EmbeddingError> {
        if texts.is_empty() {
            return Err(EmbeddingError::InvalidInput(
                "Texts cannot be empty".to_string(),
            ));
        }

        info!("Processing batch of {} texts", texts.len());
        self.batch_processor.process_batch(texts, options).await
    }

    pub async fn embed_stream(
        &self,
        texts: Vec<String>,
        options: &EmbeddingOptions,
    ) -> Result<Vec<Embedding>, EmbeddingError> {
        self.batch_processor.process_stream(texts, options).await
    }

    async fn embed_with_fallback(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, EmbeddingError> {
        match self.provider.embed(text, options).await {
            Ok(embedding) => Ok(embedding),
            Err(e) => {
                warn!("Primary provider failed: {}", e);

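                // Walk the fallback chain in registration order, skipping
                // providers that report themselves unavailable.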
                for (idx, fallback) in self.fallback_providers.iter().enumerate() {
                    debug!("Trying fallback provider {}", idx);

                    if !fallback.is_available().await {
                        warn!("Fallback provider {} not available", idx);
                        continue;
                    }

                    match fallback.embed(text, options).await {
                        Ok(embedding) => {
                            info!("Fallback provider {} succeeded", idx);
                            return Ok(embedding);
                        }
                        Err(fallback_err) => {
                            warn!("Fallback provider {} failed: {}", idx, fallback_err);
                        }
                    }
                }

                Err(e)
            }
        }
    }

    pub async fn invalidate_cache(&self, text: &str) -> Result<(), EmbeddingError> {
        if let Some(cache) = &self.cache {
            let key = crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
            cache.invalidate(&key).await;
            Ok(())
        } else {
            Err(EmbeddingError::CacheError(
                "No cache configured".to_string(),
            ))
        }
    }

    pub async fn clear_cache(&self) -> Result<(), EmbeddingError> {
        if let Some(cache) = &self.cache {
            cache.clear().await;
            Ok(())
        } else {
            Err(EmbeddingError::CacheError(
                "No cache configured".to_string(),
            ))
        }
    }

    pub fn cache_size(&self) -> usize {
        self.cache.as_ref().map(|c| c.size()).unwrap_or(0)
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{cache::MemoryCache, providers::FastEmbedProvider};

    #[tokio::test]
    async fn test_service_basic() {
        let provider = FastEmbedProvider::small().expect("Failed to init");
        let service: EmbeddingService<_, MemoryCache> = EmbeddingService::new(provider);

        let options = EmbeddingOptions::no_cache();
        let embedding = service
            .embed("Hello world", &options)
            .await
            .expect("Failed to embed");

        assert_eq!(embedding.len(), 384);
    }

    #[tokio::test]
    async fn test_service_with_cache() {
        let provider = FastEmbedProvider::small().expect("Failed to init");
        let cache = MemoryCache::new(100, Duration::from_secs(60));
        let service = EmbeddingService::new(provider).with_cache(cache);

        let options = EmbeddingOptions::default_with_cache();

        let embedding1 = service
            .embed("Hello world", &options)
            .await
            .expect("Failed first embed");
        assert_eq!(service.cache_size(), 1);

        let embedding2 = service
            .embed("Hello world", &options)
            .await
            .expect("Failed second embed");
        assert_eq!(embedding1, embedding2);
    }

    #[tokio::test]
    async fn test_service_batch() {
        let provider = FastEmbedProvider::small().expect("Failed to init");
        let cache = MemoryCache::with_defaults();
        let service = EmbeddingService::new(provider).with_cache(cache);

        let texts = vec![
            "Text 1".to_string(),
            "Text 2".to_string(),
            "Text 3".to_string(),
        ];

        let options = EmbeddingOptions::default_with_cache();
        let result = service
            .embed_batch(texts.clone(), &options)
            .await
            .expect("Failed batch");

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.cached_count, 0);

        let result2 = service
            .embed_batch(texts, &options)
            .await
            .expect("Failed second batch");
        assert_eq!(result2.cached_count, 3);
    }

    #[tokio::test]
    async fn test_service_cache_invalidation() {
        let provider = FastEmbedProvider::small().expect("Failed to init");
        let cache = MemoryCache::with_defaults();
        let service = EmbeddingService::new(provider).with_cache(cache);

        let options = EmbeddingOptions::default_with_cache();

        service
            .embed("Test text", &options)
            .await
            .expect("Failed embed");
        assert_eq!(service.cache_size(), 1);

        service
            .invalidate_cache("Test text")
            .await
            .expect("Failed to invalidate");
        assert_eq!(service.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_service_clear_cache() {
        let provider = FastEmbedProvider::small().expect("Failed to init");
        let cache = MemoryCache::with_defaults();
        let service = EmbeddingService::new(provider).with_cache(cache);

        let options = EmbeddingOptions::default_with_cache();

        service.embed("Test 1", &options).await.expect("Failed");
        service.embed("Test 2", &options).await.expect("Failed");
        assert_eq!(service.cache_size(), 2);

        service.clear_cache().await.expect("Failed to clear");
        assert_eq!(service.cache_size(), 0);
    }
}
430
crates/stratum-embeddings/src/store/lancedb.rs
Normal file
@ -0,0 +1,430 @@
#[cfg(feature = "lancedb-store")]
use std::sync::Arc;

#[cfg(feature = "lancedb-store")]
use arrow::{
    array::{
        AsArray, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
    },
    datatypes::{DataType, Field, Float32Type, Schema},
};
#[cfg(feature = "lancedb-store")]
use async_trait::async_trait;
#[cfg(feature = "lancedb-store")]
use futures::TryStreamExt;
#[cfg(feature = "lancedb-store")]
use lancedb::{query::ExecutableQuery, query::QueryBase, Connection, DistanceType, Table};

#[cfg(feature = "lancedb-store")]
use crate::{
    error::EmbeddingError,
    store::{SearchFilter, SearchResult, VectorStore, VectorStoreConfig},
    traits::Embedding,
};

pub struct LanceDbStore {
    #[allow(dead_code)]
    connection: Connection,
    table: Table,
    dimensions: usize,
}

impl LanceDbStore {
    pub async fn new(
        path: &str,
        table_name: &str,
        config: VectorStoreConfig,
    ) -> Result<Self, EmbeddingError> {
        let connection = lancedb::connect(path).execute().await.map_err(|e| {
            EmbeddingError::Initialization(format!("LanceDB connection failed: {}", e))
        })?;

        let table = match connection.open_table(table_name).execute().await {
            Ok(t) => t,
            Err(_) => {
                let schema = Self::create_schema(config.dimensions);
                let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
                let batches = RecordBatchIterator::new(
                    vec![Ok(empty_batch)].into_iter(),
                    Arc::clone(&schema),
                );

                connection
                    .create_table(table_name, batches)
                    .execute()
                    .await
                    .map_err(|e| {
                        EmbeddingError::Initialization(format!("Failed to create table: {}", e))
                    })?
            }
        };

        Ok(Self {
            connection,
            table,
            dimensions: config.dimensions,
        })
    }

    fn create_schema(dimensions: usize) -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(
                    Arc::new(Field::new("item", DataType::Float32, true)),
                    dimensions as i32,
                ),
                false,
            ),
            Field::new("metadata", DataType::Utf8, true),
        ]))
    }

    fn create_record_batch(
        ids: Vec<String>,
        embeddings: Vec<Embedding>,
        metadata: Vec<String>,
        dimensions: usize,
    ) -> Result<RecordBatch, EmbeddingError> {
        let id_array = Arc::new(StringArray::from(ids));
        let metadata_array = Arc::new(StringArray::from(metadata));

        let flat_values: Vec<f32> = embeddings.into_iter().flatten().collect();
        let values_array = Arc::new(Float32Array::from(flat_values));
        let vector_array = Arc::new(FixedSizeListArray::new(
            Arc::new(Field::new("item", DataType::Float32, true)),
            dimensions as i32,
            values_array,
            None,
        ));

        let schema = Self::create_schema(dimensions);

        RecordBatch::try_new(schema, vec![id_array, vector_array, metadata_array]).map_err(|e| {
            EmbeddingError::StoreError(format!("Failed to create record batch: {}", e))
        })
    }
}

#[async_trait]
impl VectorStore for LanceDbStore {
    fn name(&self) -> &str {
        "lancedb"
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    async fn upsert(
        &self,
        id: &str,
        embedding: &Embedding,
        metadata: serde_json::Value,
    ) -> Result<(), EmbeddingError> {
        if embedding.len() != self.dimensions {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions,
                actual: embedding.len(),
            });
        }

        let metadata_str = serde_json::to_string(&metadata)
            .map_err(|e| EmbeddingError::SerializationError(e.to_string()))?;

        let batch = Self::create_record_batch(
            vec![id.to_string()],
            vec![embedding.clone()],
            vec![metadata_str],
            self.dimensions,
        )?;

        let schema = Self::create_schema(self.dimensions);
        let batches = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);

        self.table
            .add(batches)
            .execute()
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Failed to upsert: {}", e)))?;

        Ok(())
    }

    async fn upsert_batch(
        &self,
        items: Vec<(String, Embedding, serde_json::Value)>,
    ) -> Result<(), EmbeddingError> {
        if items.is_empty() {
            return Ok(());
        }

        for (_, embedding, _) in &items {
            if embedding.len() != self.dimensions {
                return Err(EmbeddingError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: embedding.len(),
                });
            }
        }

        let (ids, embeddings, metadata): (Vec<_>, Vec<_>, Vec<_>) = items.into_iter().fold(
            (Vec::new(), Vec::new(), Vec::new()),
            |(mut ids, mut embeddings, mut metadata), (id, embedding, meta)| {
                ids.push(id);
                embeddings.push(embedding);
                metadata.push(serde_json::to_string(&meta).unwrap_or_default());
                (ids, embeddings, metadata)
            },
        );

        let batch = Self::create_record_batch(ids, embeddings, metadata, self.dimensions)?;

        let schema = Self::create_schema(self.dimensions);
        let batches = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);

        self.table
            .add(batches)
            .execute()
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Failed to batch upsert: {}", e)))?;

        Ok(())
    }

    async fn search(
        &self,
        embedding: &Embedding,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<SearchResult>, EmbeddingError> {
        if embedding.len() != self.dimensions {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions,
                actual: embedding.len(),
            });
        }

        let mut query = self
            .table
            .vector_search(embedding.as_slice())
            .map_err(|e| EmbeddingError::StoreError(format!("Query setup failed: {}", e)))?
            .distance_type(DistanceType::Cosine);

        if let Some(ref f) = filter {
            if f.min_score.is_some() {
                query = query.postfilter().refine_factor(10);
            }
        }

        let stream =
            query.limit(limit).execute().await.map_err(|e| {
                EmbeddingError::StoreError(format!("Query execution failed: {}", e))
            })?;

        let batches: Vec<RecordBatch> = stream
            .try_collect()
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Failed to collect results: {}", e)))?;

        let mut search_results = Vec::new();

        for batch in batches.iter() {
            let ids: &Arc<dyn arrow::array::Array> =
                batch.column_by_name("id").ok_or_else(|| {
                    EmbeddingError::StoreError("Missing 'id' column in results".to_string())
                })?;

            let metadata_col: &Arc<dyn arrow::array::Array> =
                batch.column_by_name("metadata").ok_or_else(|| {
                    EmbeddingError::StoreError("Missing 'metadata' column in results".to_string())
                })?;

            let distance_col: Option<&Arc<dyn arrow::array::Array>> =
                batch.column_by_name("_distance");

            let id_array = ids.as_string::<i32>();
            let metadata_array = metadata_col.as_string::<i32>();

            let num_rows: usize = batch.num_rows();
            for i in 0..num_rows {
                let id = id_array.value(i).to_string();

                let metadata_str = metadata_array.value(i);
                let metadata: serde_json::Value =
                    serde_json::from_str(metadata_str).unwrap_or(serde_json::json!({}));

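                // LanceDB reports cosine *distance* in the `_distance` column;
                // convert it to a similarity-style score via 1 - distance.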
                let score = if let Some(dist_col) = distance_col {
                    let dist_array = dist_col.as_primitive::<Float32Type>();
                    1.0 - dist_array.value(i)
                } else {
                    0.0
                };

                search_results.push(SearchResult {
                    id,
                    score,
                    embedding: None,
                    metadata,
                });
            }
        }

        if let Some(f) = filter {
            if let Some(min_score) = f.min_score {
                search_results.retain(|r| r.score >= min_score);
            }
        }

        Ok(search_results)
    }

    async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError> {
        let stream = self
            .table
            .query()
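            // Note: `id` is interpolated directly into the filter expression;
            // ids containing single quotes would need escaping.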
            .only_if(format!("id = '{}'", id))
            .limit(1)
            .execute()
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Query execution failed: {}", e)))?;

        let results: Vec<RecordBatch> = stream
            .try_collect()
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Failed to collect results: {}", e)))?;

        for batch in results.iter() {
            let num_rows: usize = batch.num_rows();
            if num_rows == 0 {
                continue;
            }

            let ids: &Arc<dyn arrow::array::Array> =
                batch.column_by_name("id").ok_or_else(|| {
                    EmbeddingError::StoreError("Missing 'id' column in results".to_string())
                })?;

            let metadata_col: &Arc<dyn arrow::array::Array> =
                batch.column_by_name("metadata").ok_or_else(|| {
                    EmbeddingError::StoreError("Missing 'metadata' column in results".to_string())
                })?;

            let id_array = ids.as_string::<i32>();
            let metadata_array = metadata_col.as_string::<i32>();

            let result_id = id_array.value(0).to_string();
            let metadata_str = metadata_array.value(0);
            let metadata: serde_json::Value =
                serde_json::from_str(metadata_str).unwrap_or(serde_json::json!({}));

            return Ok(Some(SearchResult {
                id: result_id,
                score: 1.0,
                embedding: None,
                metadata,
            }));
        }

        Ok(None)
    }

    async fn delete(&self, id: &str) -> Result<bool, EmbeddingError> {
        self.table
            .delete(&format!("id = '{}'", id))
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Delete failed: {}", e)))?;

        Ok(true)
    }

    async fn flush(&self) -> Result<(), EmbeddingError> {
        Ok(())
    }

    async fn count(&self) -> Result<usize, EmbeddingError> {
        let count = self
            .table
            .count_rows(None)
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Count failed: {}", e)))?;

        Ok(count)
    }
}

#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::*;

    #[tokio::test]
    async fn test_lancedb_store_basic() {
        let dir = tempdir().expect("Failed to create temp dir");
        let path = dir.path().to_str().unwrap();

        let config = VectorStoreConfig::new(384);
        let store = LanceDbStore::new(path, "test_embeddings", config)
            .await
            .expect("Failed to create store");

        let embedding = vec![0.1; 384];
        let metadata = serde_json::json!({"text": "hello world"});

        store
            .upsert("test_id", &embedding, metadata.clone())
            .await
            .expect("Failed to upsert");

        let result = store.get("test_id").await.expect("Failed to get");
        assert!(result.is_some());

        let search_results = store
            .search(&embedding, 5, None)
            .await
            .expect("Failed to search");
        assert!(!search_results.is_empty());

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_lancedb_store_batch() {
        let dir = tempdir().expect("Failed to create temp dir");
        let path = dir.path().to_str().unwrap();

        let config = VectorStoreConfig::new(384);
        let store = LanceDbStore::new(path, "test_batch", config)
            .await
            .expect("Failed to create store");

        let items = vec![
            (
                "id1".to_string(),
                vec![0.1; 384],
                serde_json::json!({"idx": 1}),
            ),
            (
                "id2".to_string(),
                vec![0.2; 384],
                serde_json::json!({"idx": 2}),
            ),
            (
                "id3".to_string(),
                vec![0.3; 384],
                serde_json::json!({"idx": 3}),
            ),
        ];

        store
            .upsert_batch(items)
            .await
            .expect("Failed to batch upsert");

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 3);
    }
}
14
crates/stratum-embeddings/src/store/mod.rs
Normal file
@ -0,0 +1,14 @@
pub mod traits;

#[cfg(feature = "lancedb-store")]
pub mod lancedb;

#[cfg(feature = "surrealdb-store")]
pub mod surrealdb;

pub use traits::*;

#[cfg(feature = "lancedb-store")]
pub use self::lancedb::LanceDbStore;
#[cfg(feature = "surrealdb-store")]
pub use self::surrealdb::SurrealDbStore;
298
crates/stratum-embeddings/src/store/surrealdb.rs
Normal file
@ -0,0 +1,298 @@
#[cfg(feature = "surrealdb-store")]
use async_trait::async_trait;
#[cfg(feature = "surrealdb-store")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "surrealdb-store")]
use surrealdb::{
    engine::local::{Db, Mem},
    sql::Thing,
    Surreal,
};

#[cfg(feature = "surrealdb-store")]
use crate::{
    error::EmbeddingError,
    store::{SearchFilter, SearchResult, VectorStore, VectorStoreConfig},
    traits::Embedding,
};

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingRecord {
    id: Option<Thing>,
    vector: Vec<f32>,
    metadata: serde_json::Value,
}

pub struct SurrealDbStore {
    db: Surreal<Db>,
    table: String,
    dimensions: usize,
}

impl SurrealDbStore {
    pub async fn new(
        connection: Surreal<Db>,
        table_name: &str,
        config: VectorStoreConfig,
    ) -> Result<Self, EmbeddingError> {
        Ok(Self {
            db: connection,
            table: table_name.to_string(),
            dimensions: config.dimensions,
        })
    }

    pub async fn new_memory(
        table_name: &str,
        config: VectorStoreConfig,
    ) -> Result<Self, EmbeddingError> {
        let db = Surreal::new::<Mem>(()).await.map_err(|e| {
            EmbeddingError::Initialization(format!("SurrealDB connection failed: {}", e))
        })?;

        db.use_ns("embeddings")
            .use_db("embeddings")
            .await
            .map_err(|e| {
                EmbeddingError::Initialization(format!("Failed to set namespace: {}", e))
            })?;

        Self::new(db, table_name, config).await
    }

    fn compute_cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if mag_a > 0.0 && mag_b > 0.0 {
            dot / (mag_a * mag_b)
        } else {
            0.0
        }
    }
}

#[async_trait]
impl VectorStore for SurrealDbStore {
    fn name(&self) -> &str {
        "surrealdb"
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    async fn upsert(
        &self,
        id: &str,
        embedding: &Embedding,
        metadata: serde_json::Value,
    ) -> Result<(), EmbeddingError> {
        if embedding.len() != self.dimensions {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions,
                actual: embedding.len(),
            });
        }

        let record = EmbeddingRecord {
            id: None,
            vector: embedding.clone(),
            metadata,
        };

        let _: Option<EmbeddingRecord> = self
            .db
            .update((self.table.as_str(), id))
            .content(record)
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Upsert failed: {}", e)))?;

        Ok(())
    }

    async fn search(
        &self,
        embedding: &Embedding,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<SearchResult>, EmbeddingError> {
        if embedding.len() != self.dimensions {
            return Err(EmbeddingError::DimensionMismatch {
                expected: self.dimensions,
                actual: embedding.len(),
            });
        }

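        // Brute-force search: loads every record in the table and scores it
        // in memory. Fine for the embedded/in-memory engine; large tables
        // would need a native vector index instead.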
        let all_records: Vec<EmbeddingRecord> = self
            .db
            .select(&self.table)
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Search failed: {}", e)))?;

        let mut scored_results: Vec<(String, f32, serde_json::Value)> = all_records
            .into_iter()
            .filter_map(|record| {
                let id = record.id?.id.to_string();
                let score = self.compute_cosine_similarity(embedding, &record.vector);
                Some((id, score, record.metadata))
            })
            .collect();

        scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        if let Some(f) = filter {
            if let Some(min_score) = f.min_score {
                scored_results.retain(|(_, score, _)| *score >= min_score);
            }
        }

        let results = scored_results
            .into_iter()
            .take(limit)
            .map(|(id, score, metadata)| SearchResult {
                id,
                score,
                embedding: None,
                metadata,
            })
            .collect();

        Ok(results)
    }

    async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError> {
        let record: Option<EmbeddingRecord> = self
            .db
            .select((self.table.as_str(), id))
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Get failed: {}", e)))?;

        Ok(record.map(|r| SearchResult {
            id: id.to_string(),
            score: 1.0,
            embedding: Some(r.vector),
            metadata: r.metadata,
        }))
    }

    async fn delete(&self, id: &str) -> Result<bool, EmbeddingError> {
        let result: Option<EmbeddingRecord> = self
            .db
            .delete((self.table.as_str(), id))
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Delete failed: {}", e)))?;

        Ok(result.is_some())
    }

    async fn flush(&self) -> Result<(), EmbeddingError> {
        Ok(())
    }

    async fn count(&self) -> Result<usize, EmbeddingError> {
        let records: Vec<EmbeddingRecord> = self
            .db
            .select(&self.table)
            .await
            .map_err(|e| EmbeddingError::StoreError(format!("Count failed: {}", e)))?;

        Ok(records.len())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_surrealdb_store_basic() {
        let config = VectorStoreConfig::new(384);
        let store = SurrealDbStore::new_memory("test_embeddings", config)
            .await
            .expect("Failed to create store");

        let embedding = vec![0.1; 384];
        let metadata = serde_json::json!({"text": "hello world"});

        store
            .upsert("test_id", &embedding, metadata.clone())
            .await
            .expect("Failed to upsert");

        let result = store.get("test_id").await.expect("Failed to get");
        assert!(result.is_some());
        assert_eq!(result.as_ref().unwrap().id, "test_id");
        assert_eq!(result.as_ref().unwrap().metadata, metadata);

        let search_results = store
            .search(&embedding, 5, None)
            .await
            .expect("Failed to search");
        assert_eq!(search_results.len(), 1);
        assert!(search_results[0].score > 0.99);

        let count = store.count().await.expect("Failed to count");
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_surrealdb_store_search() {
        let config = VectorStoreConfig::new(3);
        let store = SurrealDbStore::new_memory("test_search", config)
            .await
            .expect("Failed to create store");

        let embeddings = [
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0],
        ];

        for (i, embedding) in embeddings.iter().enumerate() {
            store
                .upsert(
                    &format!("id_{}", i),
                    embedding,
                    serde_json::json!({"idx": i}),
                )
                .await
                .expect("Failed to upsert");
        }

        let query = vec![1.0, 0.0, 0.0];
        let results = store
            .search(&query, 2, None)
            .await
            .expect("Failed to search");

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "id_0");
        assert!(results[0].score > 0.99);
    }

    #[tokio::test]
    async fn test_surrealdb_store_delete() {
        let config = VectorStoreConfig::new(384);
        let store = SurrealDbStore::new_memory("test_delete", config)
            .await
            .expect("Failed to create store");

        let embedding = vec![0.1; 384];
        store
            .upsert("test_id", &embedding, serde_json::json!({}))
            .await
            .expect("Failed to upsert");

        let deleted = store.delete("test_id").await.expect("Failed to delete");
        assert!(deleted);

        let result = store.get("test_id").await.expect("Failed to get");
        assert!(result.is_none());
    }
}
102
crates/stratum-embeddings/src/store/traits.rs
Normal file
@ -0,0 +1,102 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{error::EmbeddingError, traits::Embedding};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub id: String,
    pub score: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    pub metadata: serde_json::Value,
}

#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
    pub metadata: Option<serde_json::Value>,
    pub min_score: Option<f32>,
}

#[derive(Debug, Clone)]
pub struct VectorStoreConfig {
    pub dimensions: usize,
    pub metric: DistanceMetric,
    pub options: serde_json::Value,
}

impl VectorStoreConfig {
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            metric: DistanceMetric::default(),
            options: serde_json::json!({}),
        }
    }

    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    pub fn with_options(mut self, options: serde_json::Value) -> Self {
        self.options = options;
        self
    }
}

#[derive(Debug, Clone, Copy, Default)]
pub enum DistanceMetric {
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
}

#[async_trait]
pub trait VectorStore: Send + Sync {
    fn name(&self) -> &str;
    fn dimensions(&self) -> usize;

    async fn upsert(
        &self,
        id: &str,
        embedding: &Embedding,
        metadata: serde_json::Value,
    ) -> Result<(), EmbeddingError>;

    async fn upsert_batch(
        &self,
        items: Vec<(String, Embedding, serde_json::Value)>,
    ) -> Result<(), EmbeddingError> {
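        // Default implementation: sequential single upserts. Stores with a
        // bulk write path (e.g. LanceDbStore) override this.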
        for (id, embedding, metadata) in items {
            self.upsert(&id, &embedding, metadata).await?;
        }
        Ok(())
    }

    async fn search(
        &self,
        embedding: &Embedding,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<SearchResult>, EmbeddingError>;

    async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError>;

    async fn delete(&self, id: &str) -> Result<bool, EmbeddingError>;

    async fn delete_batch(&self, ids: &[&str]) -> Result<usize, EmbeddingError> {
        let mut count = 0;
        for id in ids {
            if self.delete(id).await? {
                count += 1;
            }
        }
        Ok(count)
    }

    async fn flush(&self) -> Result<(), EmbeddingError>;

    async fn count(&self) -> Result<usize, EmbeddingError>;
}
162
crates/stratum-embeddings/src/traits.rs
Normal file
@ -0,0 +1,162 @@
use async_trait::async_trait;

pub type Embedding = Vec<f32>;

#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    pub embeddings: Vec<Embedding>,
    pub model: String,
    pub dimensions: usize,
    pub total_tokens: Option<u32>,
    pub cached_count: usize,
}

#[derive(Debug, Clone, Default)]
pub struct EmbeddingOptions {
    pub normalize: bool,
    pub truncate: bool,
    pub use_cache: bool,
}

impl EmbeddingOptions {
    pub fn default_with_cache() -> Self {
        Self {
            normalize: true,
            truncate: true,
            use_cache: true,
        }
    }

    pub fn no_cache() -> Self {
        Self {
            normalize: true,
            truncate: true,
            use_cache: false,
        }
    }
}

#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    fn name(&self) -> &str;
    fn model(&self) -> &str;
    fn dimensions(&self) -> usize;
    fn is_local(&self) -> bool;
    fn max_tokens(&self) -> usize;
    fn max_batch_size(&self) -> usize;
    fn cost_per_1m_tokens(&self) -> f64;

    async fn is_available(&self) -> bool;

    async fn embed(
        &self,
        text: &str,
        options: &EmbeddingOptions,
    ) -> Result<Embedding, crate::error::EmbeddingError>;

    async fn embed_batch(
        &self,
        texts: &[&str],
        options: &EmbeddingOptions,
    ) -> Result<EmbeddingResult, crate::error::EmbeddingError>;
}

#[derive(Debug, Clone)]
pub struct ProviderInfo {
    pub name: String,
    pub model: String,
    pub dimensions: usize,
    pub is_local: bool,
    pub cost_per_1m: f64,
    pub max_batch_size: usize,
}

impl<T: EmbeddingProvider> From<&T> for ProviderInfo {
    fn from(provider: &T) -> Self {
        Self {
            name: provider.name().to_string(),
            model: provider.model().to_string(),
            dimensions: provider.dimensions(),
            is_local: provider.is_local(),
            cost_per_1m: provider.cost_per_1m_tokens(),
            max_batch_size: provider.max_batch_size(),
        }
    }
}

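/// In-place L2 normalization: after this call a nonzero vector has unit
/// magnitude, so dot product equals cosine similarity.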
pub fn normalize_embedding(embedding: &mut Embedding) {
    let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        embedding.iter_mut().for_each(|x| *x /= magnitude);
    }
}

pub fn cosine_similarity(a: &Embedding, b: &Embedding) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if magnitude_a > 0.0 && magnitude_b > 0.0 {
        dot_product / (magnitude_a * magnitude_b)
    } else {
        0.0
    }
}

pub fn euclidean_distance(a: &Embedding, b: &Embedding) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_normalize_embedding() {
        let mut embedding = vec![3.0, 4.0];
        normalize_embedding(&mut embedding);
        assert_relative_eq!(embedding[0], 0.6, epsilon = 0.0001);
        assert_relative_eq!(embedding[1], 0.8, epsilon = 0.0001);

        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert_relative_eq!(magnitude, 1.0, epsilon = 0.0001);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_relative_eq!(cosine_similarity(&a, &b), 1.0, epsilon = 0.0001);

        let c = vec![0.0, 1.0, 0.0];
        assert_relative_eq!(cosine_similarity(&a, &c), 0.0, epsilon = 0.0001);

        let d = vec![-1.0, 0.0, 0.0];
        assert_relative_eq!(cosine_similarity(&a, &d), -1.0, epsilon = 0.0001);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert_relative_eq!(euclidean_distance(&a, &b), 5.0, epsilon = 0.0001);

        let c = vec![1.0, 1.0];
        let d = vec![1.0, 1.0];
        assert_relative_eq!(euclidean_distance(&c, &d), 0.0, epsilon = 0.0001);
    }
}
65
crates/stratum-llm/Cargo.toml
Normal file
@ -0,0 +1,65 @@
[package]
name = "stratum-llm"
version = "0.1.0"
edition.workspace = true
description = "Unified LLM abstraction with CLI detection, fallback, and caching"
license.workspace = true

[dependencies]
# Async runtime
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }

# HTTP client
reqwest = { workspace = true, features = ["stream"] }

# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true, optional = true }

# Caching
moka = { workspace = true }

# Error handling
thiserror = { workspace = true }

# Logging
tracing = { workspace = true }

# Metrics
prometheus = { workspace = true, optional = true }

# Utilities
dirs = { workspace = true }
chrono = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
which = { workspace = true, optional = true }

# Hashing for cache keys
xxhash-rust = { workspace = true }

[features]
default = ["anthropic", "openai", "ollama"]
anthropic = []
openai = []
deepseek = []
ollama = []
claude-cli = []
openai-cli = []
kogral = ["serde_yaml", "which"]
metrics = ["prometheus"]
all = [
    "anthropic",
    "openai",
    "deepseek",
    "ollama",
    "claude-cli",
    "openai-cli",
    "kogral",
    "metrics",
]

[dev-dependencies]
tokio-test = { workspace = true }
131
crates/stratum-llm/README.md
Normal file
@ -0,0 +1,131 @@
# stratum-llm

Unified LLM abstraction for the stratumiops ecosystem with automatic provider detection, fallback chains, and smart caching.

## Features

- **Credential Auto-detection**: Automatically finds CLI credentials (Claude, OpenAI) and API keys
- **Provider Fallback**: Circuit breaker pattern with automatic failover across providers
- **Smart Caching**: xxHash-based request deduplication reduces duplicate API calls (sketched after this list)
- **Kogral Integration**: Inject project context from knowledge base (optional)
- **Cost Tracking**: Transparent cost estimation across all providers
- **Multiple Providers**: Anthropic Claude, OpenAI, DeepSeek, Ollama
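
How the deduplication works, as a minimal sketch. The crate's real key layout is internal; `cache_key` is a hypothetical helper, and `xxh3_64` assumes the `xxh3` feature of `xxhash-rust`:

```rust
use xxhash_rust::xxh3::xxh3_64;

// Hash model + prompt together so identical requests collapse to the
// same cache key and are served from the cache instead of the API.
fn cache_key(model: &str, prompt: &str) -> u64 {
    let mut buf = Vec::with_capacity(model.len() + prompt.len() + 1);
    buf.extend_from_slice(model.as_bytes());
    buf.push(0); // separator avoids ambiguous concatenations
    buf.extend_from_slice(prompt.as_bytes());
    xxh3_64(&buf)
}
```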

## Quick Start

```rust
use stratum_llm::{UnifiedClient, Message, Role};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = UnifiedClient::auto()?;

    let messages = vec![
        Message {
            role: Role::User,
            content: "What is Rust?".to_string(),
        }
    ];

    let response = client.generate(&messages, None).await?;
    println!("{}", response.content);

    Ok(())
}
```

## Provider Priority

1. **CLI credentials** (subscription-based, no per-token cost) - preferred
2. **API keys** from environment variables
3. **Local models** (Ollama)

The client automatically detects available credentials and builds a fallback chain.
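
To inspect the chain the client built, mirroring `examples/basic_usage.rs` (the `providers()` accessor and its fields are taken from that example):

```rust
let client = UnifiedClient::auto()?;
for provider in client.providers() {
    // Ordered by the priority above: CLI credentials, then API keys, then Ollama.
    println!(
        "{} ({}): subscription={}",
        provider.name, provider.model, provider.is_subscription
    );
}
```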
|
||||
## Features
|
||||
|
||||
### Default Features
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
stratum-llm = "0.1"
|
||||
```
|
||||
|
||||
Includes: Anthropic, OpenAI, Ollama
|
||||
|
||||
### All Features
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
stratum-llm = { version = "0.1", features = ["all"] }
|
||||
```
|
||||
|
||||
Includes: All providers, CLI detection, Kogral integration, Prometheus metrics
|
||||
|
||||
### Custom Feature Set
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
stratum-llm = { version = "0.1", features = ["anthropic", "deepseek", "kogral"] }
|
||||
```
|
||||
|
||||
Available features:
|
||||
|
||||
- `anthropic` - Anthropic Claude API
|
||||
- `openai` - OpenAI API
|
||||
- `deepseek` - DeepSeek API
|
||||
- `ollama` - Ollama local models
|
||||
- `claude-cli` - Claude CLI credential detection
|
||||
- `kogral` - Kogral knowledge base integration
|
||||
- `metrics` - Prometheus metrics
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### With Kogral Context
|
||||
|
||||
```rust
|
||||
let client = UnifiedClient::builder()
|
||||
.auto_detect()?
|
||||
.with_kogral()
|
||||
.build()?;
|
||||
|
||||
let response = client
|
||||
.generate_with_kogral(&messages, None, Some("rust"), None)
|
||||
.await?;
|
||||
```
|
||||
|
||||
### Custom Fallback Strategy
|
||||
|
||||
```rust
|
||||
use stratum_llm::{FallbackStrategy, ProviderChain};
|
||||
|
||||
let chain = ProviderChain::from_detected()?
|
||||
.with_strategy(FallbackStrategy::OnRateLimitOrUnavailable);
|
||||
|
||||
let client = UnifiedClient::builder()
|
||||
.with_chain(chain)
|
||||
.build()?;
|
||||
```
|
||||
|
||||
### Cost Budget
|
||||
|
||||
```rust
|
||||
let chain = ProviderChain::from_detected()?
|
||||
.with_strategy(FallbackStrategy::OnBudgetExceeded {
|
||||
budget_cents: 10.0,
|
||||
});
|
||||
```
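
Pass the chain to the client builder as in the previous section.

### Caching

Responses are deduplicated by default. A minimal sketch of tuning the cache via `CacheConfig` (the field values here are illustrative):

```rust
use std::time::Duration;
use stratum_llm::{CacheConfig, UnifiedClient};

let client = UnifiedClient::builder()
    .auto_detect()?
    .with_cache(CacheConfig {
        enabled: true,
        max_entries: 500,              // bounded in-memory entries (moka)
        ttl: Duration::from_secs(600), // entries expire after 10 minutes
    })
    .build()?;
```

Use `.without_cache()` on the builder to disable deduplication entirely.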

## Examples

Run examples with:

```bash
cargo run --example basic_usage
cargo run --example with_kogral --features kogral
cargo run --example fallback_demo
```

## License

MIT OR Apache-2.0
43
crates/stratum-llm/examples/basic_usage.rs
Normal file
@ -0,0 +1,43 @@
use stratum_llm::{Message, Role, UnifiedClient};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    println!("Creating UnifiedClient with auto-detected providers...");
    let client = UnifiedClient::auto()?;

    println!("\nAvailable providers:");
    for provider in client.providers() {
        println!(
            " - {} ({}): circuit={:?}, subscription={}",
            provider.name, provider.model, provider.circuit_state, provider.is_subscription
        );
    }

    let messages = vec![Message {
        role: Role::User,
        content: "What is the capital of France? Answer in one word.".to_string(),
    }];

    println!("\nSending request...");
    match client.generate(&messages, None).await {
        Ok(response) => {
            println!("\n✓ Success!");
            println!("Provider: {}", response.provider);
            println!("Model: {}", response.model);
            println!("Response: {}", response.content);
            println!(
                "Tokens: {} in, {} out",
                response.input_tokens, response.output_tokens
            );
            println!("Cost: ${:.4}", response.cost_cents / 100.0);
            println!("Latency: {}ms", response.latency_ms);
        }
        Err(e) => {
            eprintln!("\n✗ Error: {}", e);
        }
    }

    Ok(())
}
54
crates/stratum-llm/examples/fallback_demo.rs
Normal file
@ -0,0 +1,54 @@
use stratum_llm::{FallbackStrategy, Message, ProviderChain, Role, UnifiedClient};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    println!("Creating UnifiedClient with budget-based fallback strategy...");

    // Build the chain explicitly so the budget-based strategy is actually applied.
    let chain = ProviderChain::from_detected()?
        .with_strategy(FallbackStrategy::OnBudgetExceeded { budget_cents: 10.0 });
    let client = UnifiedClient::builder().with_chain(chain).build()?;

    let messages = vec![Message {
        role: Role::User,
        content: "Explain quantum computing in simple terms.".to_string(),
    }];

    println!("\nProvider chain:");
    for (idx, provider) in client.providers().iter().enumerate() {
        println!(
            " {}. {} ({}) - circuit: {:?}",
            idx + 1,
            provider.name,
            provider.model,
            provider.circuit_state
        );
    }

    println!("\nSending multiple requests to test fallback...");

    for i in 1..=3 {
        println!("\n--- Request {} ---", i);
        match client.generate(&messages, None).await {
            Ok(response) => {
                println!(
                    "✓ Provider: {} | Model: {} | Cost: ${:.4} | Latency: {}ms",
                    response.provider,
                    response.model,
                    response.cost_cents / 100.0,
                    response.latency_ms
                );
                // Truncate on a character boundary to avoid panicking on multi-byte UTF-8.
                println!(
                    "Response preview: {}...",
                    response.content.chars().take(100).collect::<String>()
                );
            }
            Err(e) => {
                eprintln!("✗ All providers failed: {}", e);
            }
        }

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    }

    Ok(())
}
45
crates/stratum-llm/examples/with_kogral.rs
Normal file
@ -0,0 +1,45 @@
#[cfg(feature = "kogral")]
use stratum_llm::{Message, Role, UnifiedClient};

#[cfg(feature = "kogral")]
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    println!("Creating UnifiedClient with Kogral integration...");
    let client = UnifiedClient::builder()
        .auto_detect()?
        .with_kogral()
        .build()?;

    let messages = vec![Message {
        role: Role::User,
        content: "Write a simple Rust function to add two numbers.".to_string(),
    }];

    println!("\nSending request with Rust guidelines from Kogral...");
    match client
        .generate_with_kogral(&messages, None, Some("rust"), None)
        .await
    {
        Ok(response) => {
            println!("\n✓ Success!");
            println!("Provider: {}", response.provider);
            println!("Model: {}", response.model);
            println!("Response:\n{}", response.content);
            println!("\nCost: ${:.4}", response.cost_cents / 100.0);
            println!("Latency: {}ms", response.latency_ms);
        }
        Err(e) => {
            eprintln!("\n✗ Error: {}", e);
        }
    }

    Ok(())
}

#[cfg(not(feature = "kogral"))]
fn main() {
    eprintln!("This example requires the 'kogral' feature.");
    eprintln!("Run with: cargo run --example with_kogral --features kogral");
}
3
crates/stratum-llm/src/cache/mod.rs
vendored
Normal file
@ -0,0 +1,3 @@
pub mod request_cache;

pub use request_cache::{CacheConfig, CacheStats, CachedResponse, RequestCache};
151
crates/stratum-llm/src/cache/request_cache.rs
vendored
Normal file
@ -0,0 +1,151 @@
use std::time::Duration;

use moka::future::Cache;
use xxhash_rust::xxh3::xxh3_64;

#[derive(Clone)]
pub struct CachedResponse {
    pub content: String,
    pub model: String,
    pub provider: String,
    pub cached_at: chrono::DateTime<chrono::Utc>,
}

pub struct RequestCache {
    cache: Cache<u64, CachedResponse>,
    enabled: bool,
}

#[derive(Debug, Clone)]
pub struct CacheConfig {
    pub enabled: bool,
    pub max_entries: u64,
    pub ttl: Duration,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            max_entries: 1000,
            ttl: Duration::from_secs(3600),
        }
    }
}

impl RequestCache {
    pub fn new(config: CacheConfig) -> Self {
        let cache = Cache::builder()
            .max_capacity(config.max_entries)
            .time_to_live(config.ttl)
            .build();

        Self {
            cache,
            enabled: config.enabled,
        }
    }

    fn compute_key(
        &self,
        messages: &[crate::providers::Message],
        options: &crate::providers::GenerationOptions,
    ) -> u64 {
        let mut hasher_input = String::new();

        for msg in messages {
            hasher_input.push_str(&format!("{:?}:{}\n", msg.role, msg.content));
        }

        hasher_input.push_str(&format!(
            "temp:{:?}|max:{:?}|top_p:{:?}",
            options.temperature, options.max_tokens, options.top_p,
        ));

        xxh3_64(hasher_input.as_bytes())
    }

    pub async fn get(
        &self,
        messages: &[crate::providers::Message],
        options: &crate::providers::GenerationOptions,
    ) -> Option<CachedResponse> {
        if !self.enabled {
            return None;
        }

        let key = self.compute_key(messages, options);
        self.cache.get(&key).await
    }

    pub async fn put(
        &self,
        messages: &[crate::providers::Message],
        options: &crate::providers::GenerationOptions,
        response: &crate::providers::GenerationResponse,
    ) {
        if !self.enabled {
            return;
        }

        let key = self.compute_key(messages, options);
        let cached = CachedResponse {
            content: response.content.clone(),
            model: response.model.clone(),
            provider: response.provider.clone(),
            cached_at: chrono::Utc::now(),
        };

        self.cache.insert(key, cached).await;
    }

    pub async fn get_or_generate<F, Fut>(
        &self,
        messages: &[crate::providers::Message],
        options: &crate::providers::GenerationOptions,
        generate: F,
    ) -> Result<crate::providers::GenerationResponse, crate::error::LlmError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<
            Output = Result<crate::providers::GenerationResponse, crate::error::LlmError>,
        >,
    {
        if let Some(cached) = self.get(messages, options).await {
            tracing::debug!("Cache hit");
            return Ok(crate::providers::GenerationResponse {
                content: cached.content,
                model: cached.model,
                provider: cached.provider,
                input_tokens: 0,
                output_tokens: 0,
                cost_cents: 0.0,
                latency_ms: 0,
            });
        }

        let response = generate().await?;
        self.put(messages, options, &response).await;

        Ok(response)
    }

    pub fn stats(&self) -> CacheStats {
        CacheStats {
            entry_count: self.cache.entry_count(),
            // Hit/miss counters are not tracked yet; only the entry count is live.
            hit_count: 0,
            miss_count: 0,
        }
    }

    pub fn clear(&self) {
        self.cache.invalidate_all();
    }
}

#[derive(Debug)]
pub struct CacheStats {
    pub entry_count: u64,
    pub hit_count: u64,
    pub miss_count: u64,
}
146
crates/stratum-llm/src/chain/circuit_breaker.rs
Normal file
@ -0,0 +1,146 @@
use std::sync::atomic::{AtomicU32, AtomicU8, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
    Closed = 0,
    Open = 1,
    HalfOpen = 2,
}

pub struct CircuitBreaker {
    state: AtomicU8,
    failure_count: AtomicU32,
    success_count: AtomicU32,
    config: CircuitBreakerConfig,
    last_failure_time: RwLock<Option<Instant>>,
    last_success_time: RwLock<Option<Instant>>,
}

#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    pub failure_threshold: u32,
    pub success_threshold: u32,
    pub reset_timeout: Duration,
    pub request_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            reset_timeout: Duration::from_secs(30),
            request_timeout: Duration::from_secs(60),
        }
    }
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: AtomicU8::new(CircuitState::Closed as u8),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            config,
            last_failure_time: RwLock::new(None),
            last_success_time: RwLock::new(None),
        }
    }

    pub fn state(&self) -> CircuitState {
        match self.state.load(Ordering::SeqCst) {
            0 => CircuitState::Closed,
            1 => CircuitState::Open,
            2 => CircuitState::HalfOpen,
            _ => CircuitState::Closed,
        }
    }

    pub fn should_allow(&self) -> bool {
        match self.state() {
            CircuitState::Closed => true,
            CircuitState::Open => {
                if let Some(last_failure) = *self.last_failure_time.read().unwrap() {
                    if last_failure.elapsed() >= self.config.reset_timeout {
                        self.state
                            .store(CircuitState::HalfOpen as u8, Ordering::SeqCst);
                        self.success_count.store(0, Ordering::SeqCst);
                        return true;
                    }
                }
                false
            }
            CircuitState::HalfOpen => true,
        }
    }

    pub fn record_success(&self) {
        *self.last_success_time.write().unwrap() = Some(Instant::now());
        self.failure_count.store(0, Ordering::SeqCst);

        if self.state() == CircuitState::HalfOpen {
            let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count >= self.config.success_threshold {
                self.state
                    .store(CircuitState::Closed as u8, Ordering::SeqCst);
                tracing::info!("Circuit breaker closed (recovered)");
            }
        }
    }

    pub fn record_failure(&self) {
        *self.last_failure_time.write().unwrap() = Some(Instant::now());

        match self.state() {
            CircuitState::Closed => {
                let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if count >= self.config.failure_threshold {
                    self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
                    tracing::warn!(
                        failures = count,
                        "Circuit breaker opened (too many failures)"
                    );
                }
            }
            CircuitState::HalfOpen => {
                self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
                self.success_count.store(0, Ordering::SeqCst);
                tracing::warn!("Circuit breaker reopened (half-open test failed)");
            }
            CircuitState::Open => {}
        }
    }

    pub async fn call<F, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
    where
        F: std::future::Future<Output = Result<T, E>>,
    {
        if !self.should_allow() {
            return Err(CircuitError::Open);
        }

        match tokio::time::timeout(self.config.request_timeout, f).await {
            Ok(Ok(result)) => {
                self.record_success();
                Ok(result)
            }
            Ok(Err(e)) => {
                self.record_failure();
                Err(CircuitError::Inner(e))
            }
            Err(_) => {
                self.record_failure();
                Err(CircuitError::Timeout)
            }
        }
    }
}

#[derive(Debug)]
pub enum CircuitError<E> {
    Open,
    Timeout,
    Inner(E),
}
5
crates/stratum-llm/src/chain/mod.rs
Normal file
@ -0,0 +1,5 @@
pub mod circuit_breaker;
pub mod provider_chain;

pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitError, CircuitState};
pub use provider_chain::{FallbackStrategy, ProviderChain, ProviderInfo};
208
crates/stratum-llm/src/chain/provider_chain.rs
Normal file
@ -0,0 +1,208 @@
use crate::chain::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitError};
use crate::credentials::CredentialDetector;
use crate::error::LlmError;
use crate::providers::{
    ConfiguredProvider, GenerationOptions, GenerationResponse, LlmProvider, Message,
};

pub struct ProviderChain {
    providers: Vec<ProviderWithCircuit>,
    strategy: FallbackStrategy,
}

struct ProviderWithCircuit {
    provider: Box<dyn LlmProvider>,
    circuit: CircuitBreaker,
    is_subscription: bool,
    priority: u32,
}

#[derive(Clone)]
pub enum FallbackStrategy {
    Sequential,
    OnRateLimitOrUnavailable,
    OnBudgetExceeded { budget_cents: f64 },
}

impl ProviderChain {
    pub fn from_detected() -> Result<Self, LlmError> {
        let detector = CredentialDetector::new();
        let credentials = detector.detect_all();

        if credentials.is_empty() {
            return Err(LlmError::NoProvidersAvailable);
        }

        let mut providers = Vec::new();

        for cred in credentials {
            let provider: Option<Box<dyn LlmProvider>> = match cred.provider.as_str() {
                #[cfg(feature = "anthropic")]
                "anthropic" => Some(Box::new(
                    crate::providers::AnthropicProvider::sonnet()
                        .map_err(|_| LlmError::NoProvidersAvailable)?,
                )),
                #[cfg(feature = "openai")]
                "openai" => Some(Box::new(
                    crate::providers::OpenAiProvider::gpt4o()
                        .map_err(|_| LlmError::NoProvidersAvailable)?,
                )),
                #[cfg(feature = "deepseek")]
                "deepseek" => Some(Box::new(
                    crate::providers::DeepSeekProvider::coder()
                        .map_err(|_| LlmError::NoProvidersAvailable)?,
                )),
                #[cfg(feature = "ollama")]
                "ollama" => Some(Box::new(crate::providers::OllamaProvider::default())),
                _ => None,
            };

            if let Some(provider) = provider {
                providers.push(ProviderWithCircuit {
                    provider,
                    circuit: CircuitBreaker::new(CircuitBreakerConfig::default()),
                    is_subscription: cred.is_subscription,
                    priority: if cred.is_subscription { 0 } else { 10 },
                });
            }
        }

        if providers.is_empty() {
            return Err(LlmError::NoProvidersAvailable);
        }

        providers.sort_by_key(|p| p.priority);

        Ok(Self {
            providers,
            strategy: FallbackStrategy::Sequential,
        })
    }

    pub fn with_providers(providers: Vec<ConfiguredProvider>) -> Self {
        let providers = providers
            .into_iter()
            .map(|p| ProviderWithCircuit {
                provider: p.provider,
                circuit: CircuitBreaker::new(CircuitBreakerConfig::default()),
                is_subscription: matches!(
                    p.credential_source,
                    crate::providers::CredentialSource::Cli { .. }
                ),
                priority: p.priority,
            })
            .collect();

        Self {
            providers,
            strategy: FallbackStrategy::Sequential,
        }
    }

    pub fn with_strategy(mut self, strategy: FallbackStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    pub async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let mut last_error: Option<LlmError> = None;

        for pwc in &self.providers {
            if !pwc.circuit.should_allow() {
                tracing::debug!(provider = pwc.provider.name(), "Circuit open, skipping");
                continue;
            }

            if let FallbackStrategy::OnBudgetExceeded { budget_cents } = &self.strategy {
                let estimated_cost = pwc.provider.estimate_cost(
                    estimate_tokens(messages),
                    options.max_tokens.unwrap_or(1000),
                );
                if estimated_cost > *budget_cents {
                    tracing::debug!(
                        provider = pwc.provider.name(),
                        estimated_cost,
                        budget = budget_cents,
                        "Would exceed budget, trying next"
                    );
                    continue;
                }
            }

            match pwc
                .circuit
                .call(pwc.provider.generate(messages, options))
                .await
            {
                Ok(response) => {
                    tracing::info!(
                        provider = pwc.provider.name(),
                        model = pwc.provider.model(),
                        cost_cents = response.cost_cents,
                        latency_ms = response.latency_ms,
                        "Request successful"
                    );
                    return Ok(response);
                }
                Err(CircuitError::Open) => {
                    tracing::debug!(provider = pwc.provider.name(), "Circuit open");
                    continue;
                }
                Err(CircuitError::Timeout) => {
                    tracing::warn!(provider = pwc.provider.name(), "Request timed out");
                    last_error = Some(LlmError::Timeout);
                    continue;
                }
                Err(CircuitError::Inner(e)) => {
                    tracing::warn!(provider = pwc.provider.name(), error = %e, "Request failed");

                    let should_fallback = match &self.strategy {
                        FallbackStrategy::Sequential => true,
                        FallbackStrategy::OnRateLimitOrUnavailable => {
                            matches!(e, LlmError::RateLimit(_) | LlmError::Unavailable(_))
                        }
                        FallbackStrategy::OnBudgetExceeded { .. } => true,
                    };

                    if should_fallback {
                        last_error = Some(e);
                        continue;
                    } else {
                        return Err(e);
                    }
                }
            }
        }

        Err(last_error.unwrap_or(LlmError::NoProvidersAvailable))
    }

    pub fn provider_info(&self) -> Vec<ProviderInfo> {
        self.providers
            .iter()
            .map(|p| ProviderInfo {
                name: p.provider.name().to_string(),
                model: p.provider.model().to_string(),
                is_subscription: p.is_subscription,
                circuit_state: p.circuit.state(),
            })
            .collect()
    }
}

#[derive(Debug)]
pub struct ProviderInfo {
    pub name: String,
    pub model: String,
    pub is_subscription: bool,
    pub circuit_state: crate::chain::circuit_breaker::CircuitState,
}

fn estimate_tokens(messages: &[Message]) -> u32 {
    // Rough heuristic: roughly four characters per token.
    let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
    (total_chars / 4) as u32
}
167
crates/stratum-llm/src/client.rs
Normal file
@ -0,0 +1,167 @@
use crate::cache::{CacheConfig, RequestCache};
use crate::chain::ProviderChain;
use crate::error::LlmError;
#[cfg(feature = "kogral")]
use crate::kogral::KogralIntegration;
use crate::providers::{GenerationOptions, GenerationResponse, Message};

pub struct UnifiedClient {
    chain: ProviderChain,
    cache: RequestCache,
    #[cfg(feature = "kogral")]
    kogral: Option<KogralIntegration>,
    default_options: GenerationOptions,
}

pub struct UnifiedClientBuilder {
    chain: Option<ProviderChain>,
    cache_config: CacheConfig,
    #[cfg(feature = "kogral")]
    kogral: Option<KogralIntegration>,
    default_options: GenerationOptions,
}

impl UnifiedClientBuilder {
    pub fn new() -> Self {
        Self {
            chain: None,
            cache_config: CacheConfig::default(),
            #[cfg(feature = "kogral")]
            kogral: None,
            default_options: GenerationOptions::default(),
        }
    }

    pub fn auto_detect(mut self) -> Result<Self, LlmError> {
        self.chain = Some(ProviderChain::from_detected()?);
        Ok(self)
    }

    pub fn with_chain(mut self, chain: ProviderChain) -> Self {
        self.chain = Some(chain);
        self
    }

    pub fn with_cache(mut self, config: CacheConfig) -> Self {
        self.cache_config = config;
        self
    }

    pub fn without_cache(mut self) -> Self {
        self.cache_config.enabled = false;
        self
    }

    #[cfg(feature = "kogral")]
    pub fn with_kogral(mut self) -> Self {
        self.kogral = KogralIntegration::new();
        self
    }

    pub fn with_defaults(mut self, options: GenerationOptions) -> Self {
        self.default_options = options;
        self
    }

    pub fn build(self) -> Result<UnifiedClient, LlmError> {
        let chain = self.chain.ok_or(LlmError::NoProvidersAvailable)?;

        Ok(UnifiedClient {
            chain,
            cache: RequestCache::new(self.cache_config),
            #[cfg(feature = "kogral")]
            kogral: self.kogral,
            default_options: self.default_options,
        })
    }
}

impl Default for UnifiedClientBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl UnifiedClient {
    pub fn auto() -> Result<Self, LlmError> {
        UnifiedClientBuilder::new().auto_detect()?.build()
    }

    pub fn builder() -> UnifiedClientBuilder {
        UnifiedClientBuilder::new()
    }

    pub async fn generate(
        &self,
        messages: &[Message],
        options: Option<&GenerationOptions>,
    ) -> Result<GenerationResponse, LlmError> {
        let opts = options.unwrap_or(&self.default_options);

        self.cache
            .get_or_generate(messages, opts, || self.chain.generate(messages, opts))
            .await
    }

    #[cfg(feature = "kogral")]
    pub async fn generate_with_kogral(
        &self,
        messages: &[Message],
        options: Option<&GenerationOptions>,
        language: Option<&str>,
        domain: Option<&str>,
    ) -> Result<GenerationResponse, LlmError> {
        let opts = options.unwrap_or(&self.default_options);

        let enriched_messages = if let Some(kogral) = &self.kogral {
            let mut ctx = serde_json::json!({});
            kogral
                .enrich_context(&mut ctx, language, domain)
                .await
                .map_err(|e| LlmError::Context(e.to_string()))?;

            self.inject_kogral_context(messages, &ctx)
        } else {
            messages.to_vec()
        };

        self.generate(&enriched_messages, Some(opts)).await
    }

    #[cfg(feature = "kogral")]
    fn inject_kogral_context(
        &self,
        messages: &[Message],
        kogral_ctx: &serde_json::Value,
    ) -> Vec<Message> {
        let mut result = Vec::with_capacity(messages.len());
        let mut system_found = false;

        for msg in messages {
            if matches!(msg.role, crate::providers::Role::System) && !system_found {
                let enhanced_content = format!(
                    "{}\n\n## Project Context (from Kogral)\n{}",
                    msg.content,
                    serde_json::to_string_pretty(kogral_ctx).unwrap_or_default()
                );
                result.push(Message {
                    role: msg.role,
                    content: enhanced_content,
                });
                system_found = true;
            } else {
                result.push(msg.clone());
            }
        }

        result
    }

    pub fn providers(&self) -> Vec<crate::chain::ProviderInfo> {
        self.chain.provider_info()
    }

    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}
68
crates/stratum-llm/src/credentials/claude_cli.rs
Normal file
@ -0,0 +1,68 @@
#[cfg(feature = "claude-cli")]
use crate::credentials::detector::DetectedCredential;
#[cfg(feature = "claude-cli")]
use crate::providers::CredentialSource;

#[cfg(feature = "claude-cli")]
impl crate::credentials::CredentialDetector {
    pub fn detect_claude_cli(&self) -> Option<DetectedCredential> {
        let config_dir = dirs::config_dir()?;

        let possible_paths = [
            config_dir.join("claude").join("credentials.json"),
            config_dir.join("claude-cli").join("auth.json"),
            config_dir.join("anthropic").join("credentials.json"),
        ];

        for path in &possible_paths {
            if let Some(cred) = Self::try_read_claude_credentials(path) {
                return Some(cred);
            }
        }

        None
    }

    fn try_read_claude_credentials(path: &std::path::Path) -> Option<DetectedCredential> {
        if !path.exists() {
            return None;
        }

        let content = std::fs::read_to_string(path).ok()?;
        let json: serde_json::Value = serde_json::from_str(&content).ok()?;

        let token = json.get("access_token")?;
        if !token.is_string() {
            return None;
        }

        if Self::is_token_expired(&json) {
            tracing::debug!("Claude CLI token expired");
            return None;
        }

        Some(DetectedCredential {
            provider: "anthropic".to_string(),
            source: CredentialSource::Cli {
                path: path.to_path_buf(),
            },
            is_subscription: true,
        })
    }

    fn is_token_expired(json: &serde_json::Value) -> bool {
        let Some(expires) = json.get("expires_at") else {
            return false;
        };

        let Some(exp_str) = expires.as_str() else {
            return false;
        };

        let Ok(exp) = chrono::DateTime::parse_from_rfc3339(exp_str) else {
            return false;
        };

        exp < chrono::Utc::now()
    }
}
130
crates/stratum-llm/src/credentials/detector.rs
Normal file
@ -0,0 +1,130 @@
use crate::providers::CredentialSource;

#[derive(Debug, Clone)]
pub struct DetectedCredential {
    pub provider: String,
    pub source: CredentialSource,
    pub is_subscription: bool,
}

pub struct CredentialDetector {
    check_cli: bool,
    check_env: bool,
}

impl CredentialDetector {
    pub fn new() -> Self {
        Self {
            check_cli: true,
            check_env: true,
        }
    }

    pub fn without_cli(mut self) -> Self {
        self.check_cli = false;
        self
    }

    pub fn detect_all(&self) -> Vec<DetectedCredential> {
        let mut credentials = Vec::new();

        if self.check_cli {
            #[cfg(feature = "claude-cli")]
            if let Some(cred) = self.detect_claude_cli() {
                credentials.push(cred);
            }
        }

        if self.check_env {
            if let Some(cred) = self.detect_anthropic_env() {
                credentials.push(cred);
            }
            if let Some(cred) = self.detect_openai_env() {
                credentials.push(cred);
            }
            if let Some(cred) = self.detect_deepseek_env() {
                credentials.push(cred);
            }
        }

        if let Some(cred) = self.detect_ollama() {
            credentials.push(cred);
        }

        credentials
    }

    pub fn detect_for_provider(&self, provider: &str) -> Option<DetectedCredential> {
        match provider {
            "anthropic" | "claude" => {
                #[cfg(feature = "claude-cli")]
                if self.check_cli {
                    if let Some(cred) = self.detect_claude_cli() {
                        return Some(cred);
                    }
                }
                self.detect_anthropic_env()
            }
            "openai" => self.detect_openai_env(),
            "deepseek" => self.detect_deepseek_env(),
            "ollama" => self.detect_ollama(),
            _ => None,
        }
    }

    pub fn detect_anthropic_env(&self) -> Option<DetectedCredential> {
        if std::env::var("ANTHROPIC_API_KEY").is_ok() {
            Some(DetectedCredential {
                provider: "anthropic".to_string(),
                source: CredentialSource::EnvVar {
                    name: "ANTHROPIC_API_KEY".to_string(),
                },
                is_subscription: false,
            })
        } else {
            None
        }
    }

    pub fn detect_openai_env(&self) -> Option<DetectedCredential> {
        if std::env::var("OPENAI_API_KEY").is_ok() {
            Some(DetectedCredential {
                provider: "openai".to_string(),
                source: CredentialSource::EnvVar {
                    name: "OPENAI_API_KEY".to_string(),
                },
                is_subscription: false,
            })
        } else {
            None
        }
    }

    pub fn detect_deepseek_env(&self) -> Option<DetectedCredential> {
        if std::env::var("DEEPSEEK_API_KEY").is_ok() {
            Some(DetectedCredential {
                provider: "deepseek".to_string(),
                source: CredentialSource::EnvVar {
                    name: "DEEPSEEK_API_KEY".to_string(),
                },
                is_subscription: false,
            })
        } else {
            None
        }
    }

    pub fn detect_ollama(&self) -> Option<DetectedCredential> {
        // Ollama runs locally and needs no credentials; it is always offered as a fallback.
        Some(DetectedCredential {
            provider: "ollama".to_string(),
            source: CredentialSource::None,
            is_subscription: false,
        })
    }
}

impl Default for CredentialDetector {
    fn default() -> Self {
        Self::new()
    }
}
6
crates/stratum-llm/src/credentials/mod.rs
Normal file
@ -0,0 +1,6 @@
pub mod detector;

#[cfg(feature = "claude-cli")]
pub mod claude_cli;

pub use detector::{CredentialDetector, DetectedCredential};
47
crates/stratum-llm/src/error.rs
Normal file
@ -0,0 +1,47 @@
use thiserror::Error;

#[derive(Debug, Error)]
pub enum LlmError {
    #[error("No providers available")]
    NoProvidersAvailable,

    #[error("Missing credential: {0}")]
    MissingCredential(String),

    #[error("Network error: {0}")]
    Network(String),

    #[error("API error: {0}")]
    Api(String),

    #[error("Rate limited: {0}")]
    RateLimit(String),

    #[error("Provider unavailable: {0}")]
    Unavailable(String),

    #[error("Request timeout")]
    Timeout,

    #[error("Parse error: {0}")]
    Parse(String),

    #[error("Context error: {0}")]
    Context(String),

    #[error("Circuit breaker open for provider")]
    CircuitOpen,
}

impl LlmError {
    pub fn is_rate_limit(&self) -> bool {
        matches!(self, Self::RateLimit(_))
    }

    pub fn is_retriable(&self) -> bool {
        matches!(
            self,
            Self::Network(_) | Self::RateLimit(_) | Self::Timeout | Self::Unavailable(_)
        )
    }
}
216
crates/stratum-llm/src/kogral/integration.rs
Normal file
@ -0,0 +1,216 @@
#[cfg(feature = "kogral")]
use std::path::PathBuf;

#[cfg(feature = "kogral")]
pub struct KogralIntegration {
    kogral_path: PathBuf,
}

#[cfg(feature = "kogral")]
impl KogralIntegration {
    pub fn new() -> Option<Self> {
        let possible_paths = [
            dirs::home_dir()?.join(".kogral"),
            PathBuf::from("/Users/Akasha/Development/kogral/.kogral"),
        ];

        for path in &possible_paths {
            if path.exists() {
                return Some(Self {
                    kogral_path: path.clone(),
                });
            }
        }

        None
    }

    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self {
            kogral_path: path.into(),
        }
    }

    pub async fn get_guidelines(&self, language: &str) -> Result<Vec<Guideline>, KogralError> {
        let guidelines_dir = self.kogral_path.join("default");
        let mut guidelines = Vec::new();

        if let Ok(entries) = std::fs::read_dir(&guidelines_dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_none_or(|e| e != "md") {
                    continue;
                }

                let Ok(content) = std::fs::read_to_string(&path) else {
                    continue;
                };

                if let Some(guideline) = Self::parse_guideline(&content, language) {
                    guidelines.push(guideline);
                }
            }
        }

        Ok(guidelines)
    }

    pub async fn get_patterns(&self, domain: &str) -> Result<Vec<Pattern>, KogralError> {
        let patterns_dir = self.kogral_path.join("default");
        let mut patterns = Vec::new();

        if let Ok(entries) = std::fs::read_dir(&patterns_dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().is_none_or(|e| e != "md") {
                    continue;
                }

                let Ok(content) = std::fs::read_to_string(&path) else {
                    continue;
                };

                if let Some(pattern) = Self::parse_pattern(&content, domain) {
                    patterns.push(pattern);
                }
            }
        }

        Ok(patterns)
    }

    pub async fn enrich_context(
        &self,
        context: &mut serde_json::Value,
        language: Option<&str>,
        domain: Option<&str>,
    ) -> Result<(), KogralError> {
        let mut kogral_context = serde_json::json!({});

        if let Some(lang) = language {
            let guidelines = self.get_guidelines(lang).await?;
            if !guidelines.is_empty() {
                kogral_context["guidelines"] = serde_json::to_value(&guidelines)?;
            }
        }

        if let Some(dom) = domain {
            let patterns = self.get_patterns(dom).await?;
            if !patterns.is_empty() {
                kogral_context["patterns"] = serde_json::to_value(&patterns)?;
            }
        }

        if let Some(obj) = context.as_object_mut() {
            obj.insert("kogral".to_string(), kogral_context);
        }

        Ok(())
    }

    fn parse_guideline(content: &str, language: &str) -> Option<Guideline> {
        let (frontmatter, body) = Self::split_frontmatter(content)?;
        let meta: serde_yaml::Value = serde_yaml::from_str(&frontmatter).ok()?;

        let node_type = meta.get("node_type")?.as_str()?;
        if node_type != "guideline" {
            return None;
        }

        let tags: Vec<String> = meta
            .get("tags")?
            .as_sequence()?
            .iter()
            .filter_map(|v| v.as_str().map(String::from))
            .collect();

        if !tags.iter().any(|t| t.eq_ignore_ascii_case(language)) {
            return None;
        }

        Some(Guideline {
            title: meta.get("title")?.as_str()?.to_string(),
            content: body.to_string(),
            tags,
        })
    }

    fn parse_pattern(content: &str, domain: &str) -> Option<Pattern> {
        let (frontmatter, body) = Self::split_frontmatter(content)?;
        let meta: serde_yaml::Value = serde_yaml::from_str(&frontmatter).ok()?;

        let node_type = meta.get("node_type")?.as_str()?;
        if node_type != "pattern" {
            return None;
        }

        let tags: Vec<String> = meta
            .get("tags")?
            .as_sequence()?
            .iter()
            .filter_map(|v| v.as_str().map(String::from))
            .collect();

        if !tags.iter().any(|t| t.eq_ignore_ascii_case(domain)) {
            return None;
        }

        Some(Pattern {
            title: meta.get("title")?.as_str()?.to_string(),
            content: body.to_string(),
            tags,
        })
    }

    fn split_frontmatter(content: &str) -> Option<(String, String)> {
        let content = content.trim();
        if !content.starts_with("---") {
            return None;
        }

        let after_first = &content[3..];
        let end = after_first.find("---")?;
        let frontmatter = after_first[..end].trim().to_string();
        let body = after_first[end + 3..].trim().to_string();

        Some((frontmatter, body))
    }
}

#[cfg(feature = "kogral")]
impl Default for KogralIntegration {
    fn default() -> Self {
        Self::new().unwrap_or_else(|| Self {
            kogral_path: PathBuf::from(".kogral"),
        })
    }
}

#[cfg(feature = "kogral")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct Guideline {
    pub title: String,
    pub content: String,
    pub tags: Vec<String>,
}

#[cfg(feature = "kogral")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct Pattern {
    pub title: String,
    pub content: String,
    pub tags: Vec<String>,
}

#[cfg(feature = "kogral")]
#[derive(Debug, thiserror::Error)]
pub enum KogralError {
    #[error("Kogral not found")]
    NotFound,
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("Parse error: {0}")]
    Parse(String),
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
5
crates/stratum-llm/src/kogral/mod.rs
Normal file
@ -0,0 +1,5 @@
#[cfg(feature = "kogral")]
pub mod integration;

#[cfg(feature = "kogral")]
pub use integration::{Guideline, KogralError, KogralIntegration, Pattern};
64
crates/stratum-llm/src/lib.rs
Normal file
@ -0,0 +1,64 @@
//! Unified LLM abstraction with CLI detection, fallback, and caching
//!
//! # Features
//!
//! - **Credential auto-detection**: Finds CLI credentials (Claude, OpenAI) and
//!   API keys
//! - **Provider fallback**: Automatic failover with circuit breaker pattern
//! - **Smart caching**: xxHash-based deduplication reduces duplicate API calls
//! - **Kogral integration**: Inject project context from knowledge base
//! - **Cost tracking**: Transparent cost estimation across providers
//!
//! # Quick Start
//!
//! ```no_run
//! use stratum_llm::{UnifiedClient, Message, Role};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let client = UnifiedClient::auto()?;
//!
//!     let messages = vec![
//!         Message {
//!             role: Role::User,
//!             content: "What is Rust?".to_string(),
//!         }
//!     ];
//!
//!     let response = client.generate(&messages, None).await?;
//!     println!("{}", response.content);
//!
//!     Ok(())
//! }
//! ```

pub mod cache;
pub mod chain;
pub mod client;
pub mod credentials;
pub mod error;
pub mod kogral;
pub mod metrics;
pub mod providers;

pub use cache::{CacheConfig, RequestCache};
pub use chain::{CircuitBreakerConfig, FallbackStrategy, ProviderChain};
pub use client::{UnifiedClient, UnifiedClientBuilder};
pub use credentials::{CredentialDetector, DetectedCredential};
pub use error::LlmError;
#[cfg(feature = "kogral")]
pub use kogral::{Guideline, KogralIntegration, Pattern};
#[cfg(feature = "metrics")]
pub use metrics::LlmMetrics;
#[cfg(feature = "anthropic")]
pub use providers::AnthropicProvider;
#[cfg(feature = "deepseek")]
pub use providers::DeepSeekProvider;
#[cfg(feature = "ollama")]
pub use providers::OllamaProvider;
#[cfg(feature = "openai")]
pub use providers::OpenAiProvider;
pub use providers::{
    ConfiguredProvider, CredentialSource, GenerationOptions, GenerationResponse, LlmProvider,
    Message, Role,
};
74
crates/stratum-llm/src/metrics.rs
Normal file
@ -0,0 +1,74 @@
#[cfg(feature = "metrics")]
use prometheus::{Counter, Histogram, IntGauge, Registry};

#[cfg(feature = "metrics")]
pub struct LlmMetrics {
    pub requests_total: Counter,
    pub requests_success: Counter,
    pub requests_failed: Counter,
    pub cache_hits: Counter,
    pub cache_misses: Counter,
    pub circuit_opens: Counter,
    pub fallbacks: Counter,
    pub latency_seconds: Histogram,
    pub cost_cents: Counter,
    pub active_circuits_open: IntGauge,
}

#[cfg(feature = "metrics")]
impl LlmMetrics {
    pub fn new() -> Self {
        Self {
            requests_total: Counter::new("stratum_llm_requests_total", "Total LLM requests")
                .unwrap(),
            requests_success: Counter::new(
                "stratum_llm_requests_success_total",
                "Successful LLM requests",
            )
            .unwrap(),
            requests_failed: Counter::new(
                "stratum_llm_requests_failed_total",
                "Failed LLM requests",
            )
            .unwrap(),
            cache_hits: Counter::new("stratum_llm_cache_hits_total", "Cache hits").unwrap(),
            cache_misses: Counter::new("stratum_llm_cache_misses_total", "Cache misses").unwrap(),
            circuit_opens: Counter::new("stratum_llm_circuit_opens_total", "Circuit breaker opens")
                .unwrap(),
            fallbacks: Counter::new("stratum_llm_fallbacks_total", "Provider fallbacks").unwrap(),
            latency_seconds: Histogram::with_opts(
                prometheus::HistogramOpts::new("stratum_llm_latency_seconds", "Request latency")
                    .buckets(vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0]),
            )
            .unwrap(),
            cost_cents: Counter::new("stratum_llm_cost_cents_total", "Total cost in cents")
                .unwrap(),
            active_circuits_open: IntGauge::new(
                "stratum_llm_circuits_open",
                "Currently open circuit breakers",
            )
            .unwrap(),
        }
    }

    pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
        registry.register(Box::new(self.requests_total.clone()))?;
        registry.register(Box::new(self.requests_success.clone()))?;
        registry.register(Box::new(self.requests_failed.clone()))?;
        registry.register(Box::new(self.cache_hits.clone()))?;
        registry.register(Box::new(self.cache_misses.clone()))?;
        registry.register(Box::new(self.circuit_opens.clone()))?;
        registry.register(Box::new(self.fallbacks.clone()))?;
        registry.register(Box::new(self.latency_seconds.clone()))?;
        registry.register(Box::new(self.cost_cents.clone()))?;
        registry.register(Box::new(self.active_circuits_open.clone()))?;
        Ok(())
    }
}

#[cfg(feature = "metrics")]
impl Default for LlmMetrics {
    fn default() -> Self {
        Self::new()
    }
}
181
crates/stratum-llm/src/providers/anthropic.rs
Normal file
@ -0,0 +1,181 @@
use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

#[cfg(feature = "anthropic")]
pub struct AnthropicProvider {
    client: reqwest::Client,
    api_key: String,
    model: String,
}

#[cfg(feature = "anthropic")]
impl AnthropicProvider {
    const BASE_URL: &'static str = "https://api.anthropic.com/v1";
    const API_VERSION: &'static str = "2023-06-01";

    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: model.into(),
        }
    }

    pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| LlmError::MissingCredential("ANTHROPIC_API_KEY".to_string()))?;
        Ok(Self::new(api_key, model))
    }

    pub fn sonnet() -> Result<Self, LlmError> {
        Self::from_env("claude-sonnet-4-5-20250929")
    }

    pub fn opus() -> Result<Self, LlmError> {
        Self::from_env("claude-opus-4-5-20251101")
    }

    pub fn haiku() -> Result<Self, LlmError> {
        Self::from_env("claude-haiku-4-5-20251001")
    }
}

#[cfg(feature = "anthropic")]
#[async_trait]
impl LlmProvider for AnthropicProvider {
    fn name(&self) -> &str {
        "anthropic"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        true
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        let (system, user_messages): (Vec<_>, Vec<_>) = messages
            .iter()
            .partition(|m| matches!(m.role, crate::providers::Role::System));

        let system_content = system.first().map(|m| m.content.as_str());

        let mut body = serde_json::json!({
            "model": self.model,
            "messages": user_messages.iter().map(|m| {
                serde_json::json!({
                    "role": match m.role {
                        crate::providers::Role::User => "user",
                        crate::providers::Role::Assistant => "assistant",
                        crate::providers::Role::System => "user",
                    },
                    "content": m.content,
                })
            }).collect::<Vec<_>>(),
            "max_tokens": options.max_tokens.unwrap_or(4096),
        });

        if let Some(sys) = system_content {
            body["system"] = serde_json::json!(sys);
        }
        if let Some(temp) = options.temperature {
            body["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = options.top_p {
            body["top_p"] = serde_json::json!(top_p);
        }
        if !options.stop_sequences.is_empty() {
            body["stop_sequences"] = serde_json::json!(options.stop_sequences);
        }

        let response = self
            .client
            .post(format!("{}/messages", Self::BASE_URL))
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", Self::API_VERSION)
            .header("content-type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit(text));
            }
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["content"][0]["text"]
            .as_str()
            .unwrap_or("")
            .to_string();

        let input_tokens = json["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32;
        let output_tokens = json["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "anthropic".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: self.estimate_cost(input_tokens, output_tokens),
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
        (input_cost + output_cost) * 100.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        match self.model.as_str() {
            m if m.contains("opus") => 15.0,
            m if m.contains("sonnet") => 3.0,
            m if m.contains("haiku") => 1.0,
            _ => 3.0,
        }
    }

    fn cost_per_1m_output(&self) -> f64 {
        match self.model.as_str() {
            m if m.contains("opus") => 75.0,
            m if m.contains("sonnet") => 15.0,
            m if m.contains("haiku") => 5.0,
            _ => 15.0,
        }
    }
}
147
crates/stratum-llm/src/providers/deepseek.rs
Normal file
@ -0,0 +1,147 @@
use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

#[cfg(feature = "deepseek")]
pub struct DeepSeekProvider {
    client: reqwest::Client,
    api_key: String,
    model: String,
}

#[cfg(feature = "deepseek")]
impl DeepSeekProvider {
    const BASE_URL: &'static str = "https://api.deepseek.com/v1";

    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: model.into(),
        }
    }

    pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
        let api_key = std::env::var("DEEPSEEK_API_KEY")
            .map_err(|_| LlmError::MissingCredential("DEEPSEEK_API_KEY".to_string()))?;
        Ok(Self::new(api_key, model))
    }

    pub fn coder() -> Result<Self, LlmError> {
        Self::from_env("deepseek-coder")
    }

    pub fn chat() -> Result<Self, LlmError> {
        Self::from_env("deepseek-chat")
    }
}

#[cfg(feature = "deepseek")]
#[async_trait]
impl LlmProvider for DeepSeekProvider {
    fn name(&self) -> &str {
        "deepseek"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        true
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        let body = serde_json::json!({
            "model": self.model,
            "messages": messages.iter().map(|m| {
                serde_json::json!({
                    "role": match m.role {
                        crate::providers::Role::System => "system",
                        crate::providers::Role::User => "user",
                        crate::providers::Role::Assistant => "assistant",
                    },
                    "content": m.content,
                })
            }).collect::<Vec<_>>(),
            "max_tokens": options.max_tokens.unwrap_or(4096),
            "temperature": options.temperature.unwrap_or(0.7),
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", Self::BASE_URL))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit(text));
            }
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        let input_tokens = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let output_tokens = json["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "deepseek".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: self.estimate_cost(input_tokens, output_tokens),
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
        (input_cost + output_cost) * 100.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        0.14
    }

    fn cost_per_1m_output(&self) -> f64 {
        0.28
    }
}
23
crates/stratum-llm/src/providers/mod.rs
Normal file
@ -0,0 +1,23 @@
pub mod traits;

#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "deepseek")]
pub mod deepseek;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "openai")]
pub mod openai;

#[cfg(feature = "anthropic")]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "deepseek")]
pub use deepseek::DeepSeekProvider;
#[cfg(feature = "ollama")]
pub use ollama::OllamaProvider;
#[cfg(feature = "openai")]
pub use openai::OpenAiProvider;
pub use traits::{
    ConfiguredProvider, CredentialSource, GenerationOptions, GenerationResponse, LlmProvider,
    Message, Role, StreamChunk, StreamResponse,
};
159
crates/stratum-llm/src/providers/ollama.rs
Normal file
@ -0,0 +1,159 @@
use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    client: reqwest::Client,
    base_url: String,
    model: String,
}

#[cfg(feature = "ollama")]
impl OllamaProvider {
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.into(),
            model: model.into(),
        }
    }

    pub fn from_env(model: impl Into<String>) -> Self {
        let base_url =
            std::env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string());
        Self::new(base_url, model)
    }

    pub fn llama3() -> Self {
        Self::from_env("llama3")
    }

    pub fn codellama() -> Self {
        Self::from_env("codellama")
    }

    pub fn mistral() -> Self {
        Self::from_env("mistral")
    }
}

#[cfg(feature = "ollama")]
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::llama3()
    }
}

#[cfg(feature = "ollama")]
#[async_trait]
impl LlmProvider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .is_ok()
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        let prompt = messages
            .iter()
            .map(|m| match m.role {
                crate::providers::Role::System => format!("System: {}", m.content),
                crate::providers::Role::User => format!("User: {}", m.content),
                crate::providers::Role::Assistant => format!("Assistant: {}", m.content),
            })
            .collect::<Vec<_>>()
            .join("\n\n");

        let mut body = serde_json::json!({
            "model": self.model,
            "prompt": prompt,
            "stream": false,
        });

        // Ollama expects sampling parameters nested under "options",
        // not at the top level of the request body.
        if let Some(temp) = options.temperature {
            body["options"]["temperature"] = serde_json::json!(temp);
        }
        if let Some(top_p) = options.top_p {
            body["options"]["top_p"] = serde_json::json!(top_p);
        }
        if !options.stop_sequences.is_empty() {
            body["options"]["stop"] = serde_json::json!(options.stop_sequences);
        }

        let response = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["response"].as_str().unwrap_or("").to_string();

        // Rough token estimate: ~4 characters per token.
        let input_tokens = prompt.len() as u32 / 4;
        let output_tokens = content.len() as u32 / 4;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "ollama".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: 0.0,
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    fn estimate_cost(&self, _input_tokens: u32, _output_tokens: u32) -> f64 {
        0.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        0.0
    }

    fn cost_per_1m_output(&self) -> f64 {
        0.0
    }
}
161
crates/stratum-llm/src/providers/openai.rs
Normal file
@ -0,0 +1,161 @@
use async_trait::async_trait;

use crate::error::LlmError;
use crate::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};

#[cfg(feature = "openai")]
pub struct OpenAiProvider {
    client: reqwest::Client,
    api_key: String,
    model: String,
}

#[cfg(feature = "openai")]
impl OpenAiProvider {
    const BASE_URL: &'static str = "https://api.openai.com/v1";

    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            model: model.into(),
        }
    }

    pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::MissingCredential("OPENAI_API_KEY".to_string()))?;
        Ok(Self::new(api_key, model))
    }

    pub fn gpt4o() -> Result<Self, LlmError> {
        Self::from_env("gpt-4o")
    }

    pub fn gpt4_turbo() -> Result<Self, LlmError> {
        Self::from_env("gpt-4-turbo")
    }

    pub fn gpt35_turbo() -> Result<Self, LlmError> {
        Self::from_env("gpt-3.5-turbo")
    }
}

#[cfg(feature = "openai")]
#[async_trait]
impl LlmProvider for OpenAiProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn is_available(&self) -> bool {
        true
    }

    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, LlmError> {
        let start = std::time::Instant::now();

        let body = serde_json::json!({
            "model": self.model,
            "messages": messages.iter().map(|m| {
                serde_json::json!({
                    "role": match m.role {
                        crate::providers::Role::System => "system",
                        crate::providers::Role::User => "user",
                        crate::providers::Role::Assistant => "assistant",
                    },
                    "content": m.content,
                })
            }).collect::<Vec<_>>(),
            "max_tokens": options.max_tokens.unwrap_or(4096),
            "temperature": options.temperature.unwrap_or(0.7),
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", Self::BASE_URL))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| LlmError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 {
                return Err(LlmError::RateLimit(text));
            }
            return Err(LlmError::Api(format!("{}: {}", status, text)));
        }

        let json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmError::Parse(e.to_string()))?;

        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        let input_tokens = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let output_tokens = json["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32;

        Ok(GenerationResponse {
            content,
            model: self.model.clone(),
            provider: "openai".to_string(),
            input_tokens,
            output_tokens,
            cost_cents: self.estimate_cost(input_tokens, output_tokens),
            latency_ms: start.elapsed().as_millis() as u64,
        })
    }

    async fn stream(
        &self,
        _messages: &[Message],
        _options: &GenerationOptions,
    ) -> Result<StreamResponse, LlmError> {
        Err(LlmError::Unavailable(
            "Streaming not yet implemented".to_string(),
        ))
    }

    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
        (input_cost + output_cost) * 100.0
    }

    fn cost_per_1m_input(&self) -> f64 {
        match self.model.as_str() {
            "gpt-4o" => 5.0,
            "gpt-4-turbo" => 10.0,
            "gpt-3.5-turbo" => 0.5,
            _ => 5.0,
        }
    }

    fn cost_per_1m_output(&self) -> f64 {
        match self.model.as_str() {
            "gpt-4o" => 15.0,
            "gpt-4-turbo" => 30.0,
            "gpt-3.5-turbo" => 1.5,
            _ => 15.0,
        }
    }
}
95
crates/stratum-llm/src/providers/traits.rs
Normal file
@ -0,0 +1,95 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

#[derive(Debug, Clone, Default)]
pub struct GenerationOptions {
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f32>,
    pub stop_sequences: Vec<String>,
}

#[derive(Debug, Clone)]
pub struct GenerationResponse {
    pub content: String,
    pub model: String,
    pub provider: String,
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub cost_cents: f64,
    pub latency_ms: u64,
}

pub type StreamChunk = String;
pub type StreamResponse = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<StreamChunk, crate::error::LlmError>> + Send>,
>;

#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Provider name (e.g., "anthropic", "openai", "ollama")
    fn name(&self) -> &str;

    /// Model identifier
    fn model(&self) -> &str;

    /// Check if provider is available and configured
    async fn is_available(&self) -> bool;

    /// Generate a completion
    async fn generate(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<GenerationResponse, crate::error::LlmError>;

    /// Stream a completion (future work)
    async fn stream(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<StreamResponse, crate::error::LlmError>;

    /// Estimate cost in cents for a request (before sending)
    fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64;

    /// Cost per 1M input tokens in US dollars (converted to cents by `estimate_cost`)
    fn cost_per_1m_input(&self) -> f64;

    /// Cost per 1M output tokens in US dollars (converted to cents by `estimate_cost`)
    fn cost_per_1m_output(&self) -> f64;
}

/// Credential source for a provider
#[derive(Debug, Clone)]
pub enum CredentialSource {
    /// CLI tool credentials (subscription-based, no per-token cost)
    Cli { path: std::path::PathBuf },
    /// API key from environment variable
    EnvVar { name: String },
    /// API key from config file
    ConfigFile { path: std::path::PathBuf },
    /// No credentials needed (local provider)
    None,
}

/// Provider with credential metadata
pub struct ConfiguredProvider {
    pub provider: Box<dyn LlmProvider>,
    pub credential_source: CredentialSource,
    pub priority: u32,
}
@ -32,6 +32,15 @@ Infrastructure automation and deployment tools.

See [Operations Portfolio Docs](en/ops/) for technical details.

### Architecture

Cross-cutting architectural decisions documented as ADRs.

- [ADR-001: Stratum-Embeddings](en/architecture/adrs/001-stratum-embeddings.md) - Unified embedding library
- [ADR-002: Stratum-LLM](en/architecture/adrs/002-stratum-llm.md) - Unified LLM provider library

See [Architecture Docs](en/architecture/) for all ADRs.

## Quick Start

1. Choose your language: [English](en/) | [Español](es/)
@ -47,3 +56,4 @@ Each language directory contains:
- `stratiumiops-technical-specs.md` - Technical specifications
- `ia/` - AI portfolio documentation
- `ops/` - Operations portfolio documentation
- `architecture/` - Architecture documentation and ADRs
@ -34,6 +34,16 @@ Infrastructure automation and deployment tools.

See [ops/](ops/) directory for full operations portfolio documentation.

### Architecture

Architectural decisions and ecosystem design.

- [**ADRs**](architecture/adrs/) - Architecture Decision Records
  - [ADR-001: Stratum-Embeddings](architecture/adrs/001-stratum-embeddings.md) - Unified embedding library
  - [ADR-002: Stratum-LLM](architecture/adrs/002-stratum-llm.md) - Unified LLM provider library

See [architecture/](architecture/) directory for full architecture documentation.

## Navigation

- [Back to root documentation](../)
30
docs/en/architecture/README.md
Normal file
@ -0,0 +1,30 @@
# Architecture

Architecture documentation for the STRATUMIOPS ecosystem.

## Contents

### ADRs (Architecture Decision Records)

Documented architectural decisions following the ADR format:

- [**ADR-001: Stratum-Embeddings**](adrs/001-stratum-embeddings.md) - Unified embedding library
- [**ADR-002: Stratum-LLM**](adrs/002-stratum-llm.md) - Unified LLM provider library

## ADR Format

Each ADR follows this structure:

| Section         | Description                                |
| --------------- | ------------------------------------------ |
| Status          | Proposed, Accepted, Deprecated, Superseded |
| Context         | Problem and current state                  |
| Decision        | Chosen solution                            |
| Rationale       | Why this solution                          |
| Consequences    | Positive, negative, mitigations            |
| Success Metrics | How to measure the outcome                 |

## Navigation

- [Back to main documentation](../)
- [Spanish version](../../es/architecture/)
279
docs/en/architecture/adrs/001-stratum-embeddings.md
Normal file
@ -0,0 +1,279 @@
# ADR-001: Stratum-Embeddings - Unified Embedding Library

## Status

**Proposed**

## Context

### Current State: Fragmented Implementations

The ecosystem has 3 independent embedding implementations:

| Project      | Location                              | Providers                     | Caching |
| ------------ | ------------------------------------- | ----------------------------- | ------- |
| Kogral       | `kogral-core/src/embeddings/`         | fastembed, rig-core (partial) | No      |
| Provisioning | `provisioning-rag/src/embeddings.rs`  | OpenAI direct                 | No      |
| Vapora       | `vapora-llm-router/src/embeddings.rs` | OpenAI, HuggingFace, Ollama   | No      |

### Identified Problems

#### 1. Duplicated Code

Each project reimplements:

- HTTP client for OpenAI embeddings
- JSON response parsing
- Error handling
- Token estimation

**Impact**: ~400 duplicated lines, inconsistent error handling.

#### 2. No Caching

Embeddings are regenerated on every call:

```text
"What is Rust?" → OpenAI → 1536 dims → $0.00002
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (same result)
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (same result)
```

**Impact**: Unnecessary costs, additional latency, more frequent rate limits.

#### 3. No Fallback

If OpenAI fails, everything fails. There is no fallback to local alternatives (fastembed, Ollama).

**Impact**: Reduced availability, total dependency on one provider.

#### 4. Silent Dimension Mismatch

Different providers produce different dimensions:

| Provider  | Model                  | Dimensions |
| --------- | ---------------------- | ---------- |
| fastembed | bge-small-en           | 384        |
| fastembed | bge-large-en           | 1024       |
| OpenAI    | text-embedding-3-small | 1536       |
| OpenAI    | text-embedding-3-large | 3072       |
| Ollama    | nomic-embed-text       | 768        |

**Impact**: Corrupt vector indices if the provider changes.

#### 5. No Metrics

No visibility into usage, cache hit rate, latency per provider, or accumulated costs.

## Decision

Create `stratum-embeddings` as a unified crate that:

1. **Unifies** implementations from Kogral, Provisioning, and Vapora
2. **Adds caching** to avoid recomputing identical embeddings
3. **Implements fallback** between providers (cloud → local)
4. **Clearly documents** dimensions and limitations per provider
5. **Exposes metrics** for observability
6. **Provides a VectorStore trait** with LanceDB and SurrealDB backends based on project needs

### Storage Backend Decision

Each project chooses its vector storage backend based on priority:

| Project      | Backend   | Priority       | Justification                                      |
| ------------ | --------- | -------------- | -------------------------------------------------- |
| Kogral       | SurrealDB | Graph richness | Knowledge Graph needs unified graph+vector queries |
| Provisioning | LanceDB   | Vector scale   | RAG with millions of document chunks               |
| Vapora       | LanceDB   | Vector scale   | Execution traces, pattern matching at scale        |

#### Why SurrealDB for Kogral

Kogral is a Knowledge Graph where relationships are the primary value. With a hybrid architecture (LanceDB vectors + SurrealDB graph), a typical query would require:

1. LanceDB: vector search → candidate_ids
2. SurrealDB: graph filter on candidates → results
3. App layer: merge, re-rank, deduplication

**Accepted trade-off**: SurrealDB has worse pure vector performance than LanceDB, but Kogral's scale is limited by human curation of knowledge (typically 10K-100K concepts).

#### Why LanceDB for Provisioning and Vapora

| Aspect          | SurrealDB  | LanceDB              |
| --------------- | ---------- | -------------------- |
| Storage format  | Row-based  | Columnar (Lance)     |
| Vector index    | HNSW (RAM) | IVF-PQ (disk-native) |
| Practical scale | Millions   | Billions             |
| Compression     | ~1x        | ~32x (PQ)            |
| Zero-copy read  | No         | Yes                  |

### Architecture

```text
┌─────────────────────────────────────────────────────────────────┐
│                       stratum-embeddings                         │
├─────────────────────────────────────────────────────────────────┤
│  EmbeddingProvider trait                                         │
│  ├─ embed(text) → Vec<f32>                                       │
│  ├─ embed_batch(texts) → Vec<Vec<f32>>                           │
│  ├─ dimensions() → usize                                         │
│  └─ is_local() → bool                                            │
│                                                                  │
│  ┌───────────┐  ┌───────────┐  ┌───────────┐                     │
│  │ FastEmbed │  │  OpenAI   │  │  Ollama   │                     │
│  │  (local)  │  │  (cloud)  │  │  (local)  │                     │
│  └───────────┘  └───────────┘  └───────────┘                     │
│        └────────────┬────────────┘                               │
│                     ▼                                            │
│             EmbeddingCache (memory/disk)                         │
│                     │                                            │
│                     ▼                                            │
│              EmbeddingService                                    │
│                     │                                            │
│                     ▼                                            │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │                    VectorStore trait                      │   │
│  │  ├─ upsert(id, embedding, metadata)                       │   │
│  │  ├─ search(embedding, limit, filter) → Vec<Match>         │   │
│  │  └─ delete(id)                                            │   │
│  └──────────────────────────────────────────────────────────┘   │
│            │                          │                          │
│            ▼                          ▼                          │
│  ┌─────────────────┐        ┌─────────────────┐                  │
│  │  SurrealDbStore │        │   LanceDbStore  │                  │
│  │    (Kogral)     │        │  (Prov/Vapora)  │                  │
│  └─────────────────┘        └─────────────────┘                  │
└─────────────────────────────────────────────────────────────────┘
```
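
As a rough Rust sketch, the trait surface drawn above could look like the following (the `EmbeddingError`, `StoreError`, and `Match` types here are illustrative placeholders, not the crate's final API):

```rust
use async_trait::async_trait;

#[derive(Debug)]
pub struct EmbeddingError(pub String);

#[derive(Debug)]
pub struct StoreError(pub String);

/// One search result returned by a vector store backend.
#[derive(Debug)]
pub struct Match {
    pub id: String,
    pub score: f32,
    pub metadata: serde_json::Value,
}

#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
    /// Output dimensionality; must match the index this provider feeds.
    fn dimensions(&self) -> usize;
    /// True for providers that work without network access.
    fn is_local(&self) -> bool;
}

#[async_trait]
pub trait VectorStore: Send + Sync {
    async fn upsert(
        &self,
        id: &str,
        embedding: &[f32],
        metadata: serde_json::Value,
    ) -> Result<(), StoreError>;
    async fn search(
        &self,
        embedding: &[f32],
        limit: usize,
        filter: Option<serde_json::Value>,
    ) -> Result<Vec<Match>, StoreError>;
    async fn delete(&self, id: &str) -> Result<(), StoreError>;
}
```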
## Rationale

### Why Caching is Critical

For a typical RAG system (10,000 chunks):

- **Without cache**: Re-indexing and repeated queries multiply costs
- **With cache**: The first indexing pays; the rest are cache hits

**Estimated savings**: 60-80% in embedding costs.
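
To make the cache concrete: a key derived from (provider, model, text) is enough for identical inputs to become cache hits. This is a minimal sketch; `DefaultHasher` is not stable across Rust releases, so a persistent disk cache would want a fixed algorithm such as SHA-256:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Illustrative cache key: identical (provider, model, text) triples map to
/// the same key, so repeated embedding requests are served from the cache.
fn embedding_cache_key(provider: &str, model: &str, text: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    (provider, model, text).hash(&mut hasher);
    hasher.finish()
}
```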
### Why Fallback is Important

| Scenario          | Without Fallback | With Fallback       |
| ----------------- | ---------------- | ------------------- |
| OpenAI rate limit | ERROR            | → fastembed (local) |
| OpenAI downtime   | ERROR            | → Ollama (local)    |
| No internet       | ERROR            | → fastembed (local) |
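
Reusing the provider trait sketched earlier, the fallback behavior in the table can amount to a simple priority loop (a sketch, not the crate's final API):

```rust
/// Try each provider in priority order (e.g., cloud first, local last) and
/// return the first successful embedding.
async fn embed_with_fallback(
    providers: &[Box<dyn EmbeddingProvider>],
    text: &str,
) -> Result<Vec<f32>, EmbeddingError> {
    let mut last_err = EmbeddingError("no providers configured".to_string());
    for provider in providers {
        match provider.embed(text).await {
            Ok(vector) => return Ok(vector),
            // Rate limit, downtime, or no network: fall through to the next
            // provider in the chain.
            Err(e) => last_err = e,
        }
    }
    Err(last_err)
}
```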
### Why Local Providers First

For development: fastembed loads a local model (~100MB), requires no API keys, costs nothing, and works offline.

For production: OpenAI for quality, fastembed as fallback.

## Consequences

### Positive

1. Single source of truth for the entire ecosystem
2. 60-80% fewer embedding API calls (caching)
3. High availability with local providers (fallback)
4. Usage and cost metrics
5. Feature-gated: only compile what you need
6. Storage flexibility: the VectorStore trait allows choosing a backend per project

### Negative

1. **Dimension lock-in**: Changing provider requires re-indexing
2. **Cache invalidation**: Updated content may serve stale embeddings
3. **Model download**: fastembed downloads ~100MB on first use
4. **Storage lock-in per project**: Kogral tied to SurrealDB, others to LanceDB

### Mitigations

| Negative          | Mitigation                                     |
| ----------------- | ---------------------------------------------- |
| Dimension lock-in | Document clearly, warn on provider change      |
| Stale cache       | Configurable TTL, bypass option                |
| Model download    | Show progress, cache in ~/.cache/fastembed     |
| Storage lock-in   | Conscious decision based on project priorities |

## Success Metrics

| Metric                    | Current | Target |
| ------------------------- | ------- | ------ |
| Duplicate implementations | 3       | 1      |
| Cache hit rate            | 0%      | >60%   |
| Fallback availability     | 0%      | 100%   |
| Cost per 10K embeddings   | ~$0.20  | ~$0.05 |

## Provider Selection Guide

### Development

```rust
// Local, free, offline
let service = EmbeddingService::builder()
    .with_provider(FastEmbedProvider::small()?) // 384 dims
    .with_memory_cache()
    .build()?;
```

### Production (Quality)

```rust
// OpenAI with local fallback
let service = EmbeddingService::builder()
    .with_provider(OpenAiEmbeddingProvider::large()?) // 3072 dims
    .with_provider(FastEmbedProvider::large()?)       // Fallback
    .with_memory_cache()
    .build()?;
```

### Production (Cost-Optimized)

```rust
// OpenAI small with fallback
let service = EmbeddingService::builder()
    .with_provider(OpenAiEmbeddingProvider::small()?) // 1536 dims
    .with_provider(OllamaEmbeddingProvider::nomic())  // Fallback
    .with_memory_cache()
    .build()?;
```

## Dimension Compatibility Matrix

| If using...            | Can switch to...            | CANNOT switch to... |
| ---------------------- | --------------------------- | ------------------- |
| fastembed small (384)  | fastembed small, all-minilm | Any other           |
| fastembed large (1024) | fastembed large             | Any other           |
| OpenAI small (1536)    | OpenAI small, ada-002       | Any other           |
| OpenAI large (3072)    | OpenAI large                | Any other           |

**Rule**: Only switch between models with the SAME dimensions.
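
A small guard (again reusing the sketched trait) turns the silent mismatch into an explicit error at startup:

```rust
/// Refuse to pair a provider with an index built at a different dimensionality.
fn check_dimensions(
    index_dims: usize,
    provider: &dyn EmbeddingProvider,
) -> Result<(), EmbeddingError> {
    if provider.dimensions() != index_dims {
        return Err(EmbeddingError(format!(
            "dimension mismatch: index has {} dims, provider produces {}",
            index_dims,
            provider.dimensions()
        )));
    }
    Ok(())
}
```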
## Implementation Priority

| Order | Feature                 | Reason                    |
| ----- | ----------------------- | ------------------------- |
| 1     | EmbeddingProvider trait | Foundation for everything |
| 2     | FastEmbed provider      | Works without API keys    |
| 3     | Memory cache            | Biggest cost impact       |
| 4     | VectorStore trait       | Storage abstraction       |
| 5     | SurrealDbStore          | Kogral needs graph+vector |
| 6     | LanceDbStore            | Provisioning/Vapora scale |
| 7     | OpenAI provider         | Production                |
| 8     | Ollama provider         | Local fallback            |
| 9     | Batch processing        | Efficiency                |
| 10    | Metrics                 | Observability             |

## References

**Existing Implementations**:

- Kogral: `kogral-core/src/embeddings/`
- Vapora: `vapora-llm-router/src/embeddings.rs`
- Provisioning: `provisioning/platform/crates/rag/src/embeddings.rs`

**Target Location**: `stratumiops/crates/stratum-embeddings/`
279
docs/en/architecture/adrs/002-stratum-llm.md
Normal file
@ -0,0 +1,279 @@
# ADR-002: Stratum-LLM - Unified LLM Provider Library

## Status

**Proposed**

## Context

### Current State: Fragmented LLM Connections

The stratumiops ecosystem has 4 projects with AI functionality, each with its own implementation:

| Project      | Implementation             | Providers              | Duplication         |
| ------------ | -------------------------- | ---------------------- | ------------------- |
| Vapora       | `typedialog-ai` (path dep) | Claude, OpenAI, Ollama | Shared base         |
| TypeDialog   | `typedialog-ai` (local)    | Claude, OpenAI, Ollama | Defines abstraction |
| Provisioning | Custom `LlmClient`         | Claude, OpenAI         | 100% duplicated     |
| Kogral       | `rig-core`                 | Embeddings only        | Different stack     |

### Identified Problems

#### 1. Code Duplication

Provisioning reimplements what TypeDialog already has:

- reqwest HTTP client
- Headers: x-api-key, anthropic-version
- JSON body formatting
- Response parsing
- Error handling

**Impact**: ~500 duplicated lines; bugs fixed in one place don't propagate.

#### 2. API Keys Only, No CLI Detection

No project detects credentials from official CLIs:

```text
Claude CLI: ~/.config/claude/credentials.json
OpenAI CLI: ~/.config/openai/credentials.json
```

**Impact**: Users with Claude Pro/Max ($20-100/month) pay for API tokens when they could use their subscription.

#### 3. No Automatic Fallback

When a provider fails (rate limit, timeout), the request fails completely:

```text
Current: Request → Claude API → Rate Limit → ERROR
Desired: Request → Claude API → Rate Limit → OpenAI → Success
```

#### 4. No Circuit Breaker

If the Claude API is down, each request attempts to connect, fails, and propagates the error:

```text
Request 1 → Claude → Timeout (30s) → Error
Request 2 → Claude → Timeout (30s) → Error
Request 3 → Claude → Timeout (30s) → Error
```

**Impact**: Accumulated latency, degraded UX.

#### 5. No Caching

Identical requests always go to the API:

```text
"Explain this Rust error" → Claude → $0.003
"Explain this Rust error" → Claude → $0.003 (same result)
```

**Impact**: Unnecessary costs, especially in development/testing.

#### 6. Kogral Not Integrated

Kogral has guidelines and patterns that could enrich LLM context, but there is no integration.

## Decision

Create `stratum-llm` as a unified crate that:

1. **Consolidates** existing implementations from typedialog-ai and provisioning
2. **Detects** CLI credentials and subscriptions before using API keys
3. **Implements** automatic fallback with circuit breaker
4. **Adds** request caching to reduce costs
5. **Integrates** Kogral for context enrichment
6. **Is used** by all ecosystem projects

### Architecture

```text
┌─────────────────────────────────────────────────────────┐
│                       stratum-llm                       │
├─────────────────────────────────────────────────────────┤
│  CredentialDetector                                     │
│  ├─ Claude CLI → ~/.config/claude/ (subscription)       │
│  ├─ OpenAI CLI → ~/.config/openai/                      │
│  ├─ Env vars → *_API_KEY                                │
│  └─ Ollama → localhost:11434 (free)                     │
│                     │                                   │
│                     ▼                                   │
│  ProviderChain (ordered by priority)                    │
│  [CLI/Sub] → [API] → [DeepSeek] → [Ollama]              │
│       │        │         │            │                 │
│       └────────┴─────────┴────────────┘                 │
│                     │                                   │
│         CircuitBreaker per provider                     │
│                     │                                   │
│               RequestCache                              │
│                     │                                   │
│             KogralIntegration                           │
│                     │                                   │
│               UnifiedClient                             │
│                                                         │
└─────────────────────────────────────────────────────────┘
```
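
As an illustration of the detection order in the diagram, a probe for Claude credentials could look like this sketch, reusing the `CredentialSource` enum from `providers/traits.rs`. The file path follows the ADR's example; real CLI credential locations and formats may vary by version, and `ANTHROPIC_API_KEY` is an assumed variable name:

```rust
use std::path::PathBuf;

use stratum_llm::providers::CredentialSource;

/// Probe order: CLI credentials (subscription, no per-token cost) first,
/// then an API key from the environment; `None` lets callers fall further
/// down the chain (e.g., to Ollama).
fn detect_claude_credentials() -> Option<CredentialSource> {
    if let Ok(home) = std::env::var("HOME") {
        let path = PathBuf::from(home).join(".config/claude/credentials.json");
        if path.exists() {
            return Some(CredentialSource::Cli { path });
        }
    }
    if std::env::var("ANTHROPIC_API_KEY").is_ok() {
        return Some(CredentialSource::EnvVar {
            name: "ANTHROPIC_API_KEY".to_string(),
        });
    }
    None
}
```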
## Rationale

### Why Not Use Another External Crate

| Alternative    | Why Not                                    |
| -------------- | ------------------------------------------ |
| kaccy-ai       | Oriented toward blockchain/fraud detection |
| llm (crate)    | Very basic, no circuit breaker or caching  |
| langchain-rust | Python port, not idiomatic Rust            |
| rig-core       | Embeddings/RAG only, no chat completion    |

**Best option**: Build on typedialog-ai and add the missing features.

### Why CLI Detection is Important

Cost analysis for a typical user:

| Scenario                  | Monthly Cost        |
| ------------------------- | ------------------- |
| API only (current)        | ~$840               |
| Claude Pro + API overflow | ~$20 + ~$200 = $220 |
| Claude Max + API overflow | ~$100 + ~$50 = $150 |

**Potential savings**: 70-80% by detecting and using subscriptions first.
### Why Circuit Breaker

Without a circuit breaker, a downed provider causes:

- N requests × 30s timeout = N×30s total latency
- All resources occupied waiting for timeouts

With a circuit breaker:

- The first failure opens the circuit
- Subsequent requests fail immediately (fast fail)
- Fallback to another provider without waiting
- The circuit resets after a cooldown (see the sketch below)
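
A minimal sketch of that open/half-open cycle (real implementations usually add a failure threshold; this only illustrates the fast-fail and cooldown behavior):

```rust
use std::time::{Duration, Instant};

struct CircuitBreaker {
    opened_at: Option<Instant>,
    cooldown: Duration,
}

impl CircuitBreaker {
    fn new(cooldown: Duration) -> Self {
        Self { opened_at: None, cooldown }
    }

    /// Returns true if a request may be attempted.
    fn allow(&mut self) -> bool {
        match self.opened_at {
            None => true,
            Some(opened) if opened.elapsed() >= self.cooldown => {
                self.opened_at = None; // half-open: allow one probe request
                true
            }
            Some(_) => false, // open: fail fast so the chain tries elsewhere
        }
    }

    fn record_failure(&mut self) {
        self.opened_at = Some(Instant::now());
    }

    fn record_success(&mut self) {
        self.opened_at = None;
    }
}
```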
### Why Caching

For typical development:

- The same questions repeat while iterating
- Tests execute the same prompts multiple times

Estimated cache hit rate: 15-30% in active development.
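
A sketch of the request cache on top of `moka` (already a workspace dependency); `call_provider` is a stand-in for the real provider chain, and the builder options shown are illustrative:

```rust
use std::time::Duration;

use moka::future::Cache;

async fn call_provider(_model: &str, _prompt: &str) -> String {
    "...".to_string() // placeholder for the real ProviderChain call
}

/// Identical (model, prompt) pairs are served from the cache; the TTL bounds
/// staleness, matching the mitigations table below.
async fn cached_generate(cache: &Cache<String, String>, model: &str, prompt: &str) -> String {
    let key = format!("{model}:{prompt}");
    if let Some(hit) = cache.get(&key).await {
        return hit; // cache hit: no API call, no cost
    }
    let response = call_provider(model, prompt).await;
    cache.insert(key, response.clone()).await;
    response
}

#[tokio::main]
async fn main() {
    let cache: Cache<String, String> = Cache::builder()
        .max_capacity(10_000)
        .time_to_live(Duration::from_secs(3600))
        .build();
    let first = cached_generate(&cache, "gpt-4o", "Explain this Rust error").await;
    let second = cached_generate(&cache, "gpt-4o", "Explain this Rust error").await;
    assert_eq!(first, second); // the second call was a cache hit
}
```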
### Why Kogral Integration

Kogral has language guidelines, domain patterns, and ADRs. Without integration the LLM generates generic code; with integration it generates code that follows project conventions.

## Consequences

### Positive

1. Single source of truth for LLM logic
2. CLI detection reduces costs 70-80%
3. Circuit breaker + fallback = high availability
4. 15-30% fewer requests in development (caching)
5. Kogral improves generation quality
6. Feature-gated: each feature is optional

### Negative

1. **Migration effort**: Refactor Vapora, TypeDialog, Provisioning
2. **New dependency**: Projects depend on stratumiops
3. **CLI auth complexity**: Different credential formats per version
4. **Cache invalidation**: Stale responses if not managed well

### Mitigations

| Negative            | Mitigation                                  |
| ------------------- | ------------------------------------------- |
| Migration effort    | Re-export compatible API from typedialog-ai |
| New dependency      | Local path dependency, not crates.io        |
| CLI auth complexity | Version detection, fallback to API if fails |
| Cache invalidation  | Configurable TTL, bypass option             |

## Success Metrics

| Metric                   | Current | Target          |
| ------------------------ | ------- | --------------- |
| Duplicated lines of code | ~500    | 0               |
| CLI credential detection | 0%      | 100%            |
| Fallback success rate    | 0%      | >90%            |
| Cache hit rate           | 0%      | 15-30%          |
| Latency (provider down)  | 30s+    | <1s (fast fail) |

## Cost Impact Analysis

Based on real usage data ($840/month):

| Scenario                   | Savings            |
| -------------------------- | ------------------ |
| CLI detection (Claude Max) | ~$700/month        |
| Caching (15% hit rate)     | ~$50/month         |
| DeepSeek fallback for code | ~$100/month        |
| **Total potential**        | **$500-700/month** |

## Migration Strategy

### Migration Phases

1. Create stratum-llm with an API compatible with typedialog-ai
2. typedialog-ai re-exports stratum-llm (backward compatible; see the sketch below)
3. Vapora migrates to stratum-llm directly
4. Provisioning migrates its LlmClient to stratum-llm
5. Deprecate typedialog-ai, consolidate in stratum-llm
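
Phase 2 could be as small as a re-export facade in typedialog-ai (illustrative; the exact module paths are an assumption):

```rust
// typedialog-ai/src/lib.rs: backward-compatible facade over stratum-llm.
// Downstream code keeps compiling while the implementation moves.
pub use stratum_llm::providers::{
    GenerationOptions, GenerationResponse, LlmProvider, Message, Role,
};
```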
### Feature Adoption

| Feature         | Adoption                                |
| --------------- | --------------------------------------- |
| Basic providers | Immediate (direct replacement)          |
| CLI detection   | Optional, feature flag                  |
| Circuit breaker | Default on                              |
| Caching         | Default on, configurable TTL            |
| Kogral          | Feature flag, requires Kogral installed |

## Alternatives Considered

### Alternative 1: Improve typedialog-ai In-Place

**Pros**: No new crate required

**Cons**: TypeDialog is a specific project, not shared infrastructure

**Decision**: stratum-llm in stratumiops is the better location for cross-project infrastructure.

### Alternative 2: Use LiteLLM (Python) as a Proxy

**Pros**: Very complete, 100+ providers

**Cons**: Python dependency, proxy latency, not Rust-native

**Decision**: Keep a pure Rust stack.

### Alternative 3: Each Project Maintains Its Own Implementation

**Pros**: Independence

**Cons**: Duplication, inconsistency, bug fixes not shared

**Decision**: Consolidation is better long-term.

## References

**Existing Implementations**:

- TypeDialog: `typedialog/crates/typedialog-ai/`
- Vapora: `vapora/crates/vapora-llm-router/`
- Provisioning: `provisioning/platform/crates/rag/`

**Kogral**: `kogral/`

**Target Location**: `stratumiops/crates/stratum-llm/`
22
docs/en/architecture/adrs/README.md
Normal file
@ -0,0 +1,22 @@
# ADRs - Architecture Decision Records

Architecture decision records for the STRATUMIOPS ecosystem.

## Active ADRs

| ID                               | Title                                         | Status   |
| -------------------------------- | --------------------------------------------- | -------- |
| [001](001-stratum-embeddings.md) | Stratum-Embeddings: Unified Embedding Library | Proposed |
| [002](002-stratum-llm.md)        | Stratum-LLM: Unified LLM Provider Library     | Proposed |

## Statuses

- **Proposed**: Under review, pending implementation
- **Accepted**: Approved and implemented
- **Deprecated**: Replaced by another ADR
- **Superseded**: Obsolete, see replacement ADR

## Navigation

- [Back to architecture](../)
- [Spanish version](../../../es/architecture/adrs/)
@ -34,6 +34,16 @@ Herramientas de automatización de infraestructura y despliegue.

Ver directorio [ops/](ops/) para documentación completa del portfolio de operaciones.

### Arquitectura

Decisiones arquitecturales y diseño del ecosistema.

- [**ADRs**](architecture/adrs/) - Architecture Decision Records
  - [ADR-001: Stratum-Embeddings](architecture/adrs/001-stratum-embeddings.md) - Biblioteca unificada de embeddings
  - [ADR-002: Stratum-LLM](architecture/adrs/002-stratum-llm.md) - Biblioteca unificada de providers LLM

Ver directorio [architecture/](architecture/) para documentación completa de arquitectura.

## Navegación

- [Volver a documentación raíz](../)
30
docs/es/architecture/README.md
Normal file
@ -0,0 +1,30 @@
# Arquitectura

Documentación de arquitectura del ecosistema STRATUMIOPS.

## Contenido

### ADRs (Architecture Decision Records)

Decisiones arquitecturales documentadas siguiendo el formato ADR:

- [**ADR-001: Stratum-Embeddings**](adrs/001-stratum-embeddings.md) - Biblioteca unificada de embeddings
- [**ADR-002: Stratum-LLM**](adrs/002-stratum-llm.md) - Biblioteca unificada de providers LLM

## Formato ADR

Cada ADR sigue la estructura:

| Sección           | Descripción                                |
| ----------------- | ------------------------------------------ |
| Estado            | Propuesto, Aceptado, Deprecado, Superseded |
| Contexto          | Problema y estado actual                   |
| Decisión          | Solución elegida                           |
| Justificación     | Por qué esta solución                      |
| Consecuencias     | Positivas, negativas, mitigaciones         |
| Métricas de Éxito | Cómo medir el resultado                    |

## Navegación

- [Volver a documentación principal](../)
- [English version](../../en/architecture/)
280
docs/es/architecture/adrs/001-stratum-embeddings.md
Normal file
@ -0,0 +1,280 @@
# ADR-001: Stratum-Embeddings - Biblioteca Unificada de Embeddings

## Estado

**Propuesto**

## Contexto

### Estado Actual: Implementaciones Fragmentadas

El ecosistema tiene 3 implementaciones independientes de embeddings:

| Proyecto     | Ubicación                             | Providers                     | Caching |
| ------------ | ------------------------------------- | ----------------------------- | ------- |
| Kogral       | `kogral-core/src/embeddings/`         | fastembed, rig-core (parcial) | No      |
| Provisioning | `provisioning-rag/src/embeddings.rs`  | OpenAI directo                | No      |
| Vapora       | `vapora-llm-router/src/embeddings.rs` | OpenAI, HuggingFace, Ollama   | No      |

### Problemas Identificados

#### 1. Código Duplicado

Cada proyecto reimplementa:

- HTTP client para OpenAI embeddings
- Parsing de respuestas JSON
- Manejo de errores
- Token estimation

**Impacto**: ~400 líneas duplicadas, inconsistencias en el manejo de errores.

#### 2. Sin Caching

Los embeddings se regeneran en cada llamada:

```text
"What is Rust?" → OpenAI → 1536 dims → $0.00002
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (mismo resultado)
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (mismo resultado)
```

**Impacto**: Costos innecesarios, latencia adicional, rate limits más frecuentes.

#### 3. No Hay Fallback

Si OpenAI falla, todo falla. No hay fallback a alternativas locales (fastembed, Ollama).

**Impacto**: Disponibilidad reducida, dependencia total de un provider.

#### 4. Dimension Mismatch Silencioso

Diferentes providers producen diferentes dimensiones:

| Provider  | Modelo                 | Dimensiones |
| --------- | ---------------------- | ----------- |
| fastembed | bge-small-en           | 384         |
| fastembed | bge-large-en           | 1024        |
| OpenAI    | text-embedding-3-small | 1536        |
| OpenAI    | text-embedding-3-large | 3072        |
| Ollama    | nomic-embed-text       | 768         |

**Impacto**: Índices vectoriales corruptos si se cambia de provider.

#### 5. Sin Métricas

No hay visibilidad de uso, hit rate de cache, latencia por provider, ni costos acumulados.

## Decisión

Crear `stratum-embeddings` como crate unificado que:

1. **Unifique** las implementaciones de Kogral, Provisioning y Vapora
2. **Añada caching** para evitar re-computar embeddings idénticos
3. **Implemente fallback** entre providers (cloud → local)
4. **Documente claramente** las dimensiones y limitaciones por provider
5. **Exponga métricas** para observabilidad
6. **Provea un VectorStore trait** con backends LanceDB y SurrealDB según la necesidad del proyecto

### Decisión de Backend de Storage

Cada proyecto elige su backend de vector storage según su prioridad:

| Proyecto     | Backend   | Prioridad         | Justificación                                             |
| ------------ | --------- | ----------------- | --------------------------------------------------------- |
| Kogral       | SurrealDB | Riqueza del grafo | Knowledge Graph necesita queries unificados graph+vector  |
| Provisioning | LanceDB   | Escala vectorial  | RAG con millones de chunks documentales                   |
| Vapora       | LanceDB   | Escala vectorial  | Traces de ejecución, pattern matching a escala            |

#### Por qué SurrealDB para Kogral

Kogral es un Knowledge Graph donde las relaciones son el valor principal. Con una arquitectura híbrida (vectores en LanceDB + grafo en SurrealDB), un query típico requeriría:

1. LanceDB: búsqueda vectorial → candidate_ids
2. SurrealDB: filtro de grafo sobre candidates → results
3. App layer: merge, re-rank, deduplicación

**Trade-off aceptado**: SurrealDB tiene peor rendimiento vectorial puro que LanceDB, pero la escala de Kogral está limitada por la curación humana del conocimiento (típicamente 10K-100K conceptos).

#### Por qué LanceDB para Provisioning y Vapora

| Aspecto         | SurrealDB  | LanceDB              |
| --------------- | ---------- | -------------------- |
| Storage format  | Row-based  | Columnar (Lance)     |
| Vector index    | HNSW (RAM) | IVF-PQ (disk-native) |
| Escala práctica | Millones   | Miles de millones    |
| Compresión      | ~1x        | ~32x (PQ)            |
| Zero-copy read  | No         | Sí                   |

### Arquitectura

```text
┌─────────────────────────────────────────────────────────────────┐
│                       stratum-embeddings                         │
├─────────────────────────────────────────────────────────────────┤
│  EmbeddingProvider trait                                         │
│  ├─ embed(text) → Vec<f32>                                       │
│  ├─ embed_batch(texts) → Vec<Vec<f32>>                           │
│  ├─ dimensions() → usize                                         │
│  └─ is_local() → bool                                            │
│                                                                  │
│  ┌───────────┐  ┌───────────┐  ┌───────────┐                     │
│  │ FastEmbed │  │  OpenAI   │  │  Ollama   │                     │
│  │  (local)  │  │  (cloud)  │  │  (local)  │                     │
│  └───────────┘  └───────────┘  └───────────┘                     │
│        └────────────┬────────────┘                               │
│                     ▼                                            │
│             EmbeddingCache (memory/disk)                         │
│                     │                                            │
│                     ▼                                            │
│              EmbeddingService                                    │
│                     │                                            │
│                     ▼                                            │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │                    VectorStore trait                      │   │
│  │  ├─ upsert(id, embedding, metadata)                       │   │
│  │  ├─ search(embedding, limit, filter) → Vec<Match>         │   │
│  │  └─ delete(id)                                            │   │
│  └──────────────────────────────────────────────────────────┘   │
│            │                          │                          │
│            ▼                          ▼                          │
│  ┌─────────────────┐        ┌─────────────────┐                  │
│  │  SurrealDbStore │        │   LanceDbStore  │                  │
│  │    (Kogral)     │        │  (Prov/Vapora)  │                  │
│  └─────────────────┘        └─────────────────┘                  │
└─────────────────────────────────────────────────────────────────┘
```
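
Como boceto ilustrativo en Rust de los traits del diagrama anterior (los tipos `EmbeddingError`, `StoreError` y `Match` son placeholders, no la API final del crate):

```rust
use async_trait::async_trait;

#[derive(Debug)]
pub struct EmbeddingError(pub String);

#[derive(Debug)]
pub struct StoreError(pub String);

/// Un resultado de búsqueda devuelto por el backend de vector storage.
#[derive(Debug)]
pub struct Match {
    pub id: String,
    pub score: f32,
    pub metadata: serde_json::Value,
}

#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
    /// Dimensionalidad de salida; debe coincidir con el índice que alimenta.
    fn dimensions(&self) -> usize;
    /// True para providers que funcionan sin acceso a red.
    fn is_local(&self) -> bool;
}

#[async_trait]
pub trait VectorStore: Send + Sync {
    async fn upsert(
        &self,
        id: &str,
        embedding: &[f32],
        metadata: serde_json::Value,
    ) -> Result<(), StoreError>;
    async fn search(
        &self,
        embedding: &[f32],
        limit: usize,
        filter: Option<serde_json::Value>,
    ) -> Result<Vec<Match>, StoreError>;
    async fn delete(&self, id: &str) -> Result<(), StoreError>;
}
```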
## Justificación

### Por Qué Caching es Crítico

Para un sistema RAG típico (10,000 chunks):

- **Sin cache**: Re-indexaciones y queries repetidas multiplican costos
- **Con cache**: La primera indexación paga; el resto son cache hits

**Ahorro estimado**: 60-80% en costos de embeddings.

### Por Qué Fallback es Importante

| Escenario         | Sin Fallback | Con Fallback        |
| ----------------- | ------------ | ------------------- |
| OpenAI rate limit | ERROR        | → fastembed (local) |
| OpenAI downtime   | ERROR        | → Ollama (local)    |
| Sin internet      | ERROR        | → fastembed (local) |

### Por Qué Providers Locales Primero

Para desarrollo: fastembed carga un modelo local (~100MB), no requiere API keys, no tiene costos y funciona offline.

Para producción: OpenAI para calidad, fastembed como fallback.

## Consecuencias

### Positivas

1. Single source of truth para todo el ecosistema
2. 60-80% menos llamadas a APIs de embeddings (caching)
3. Alta disponibilidad con providers locales (fallback)
4. Métricas de uso y costos
5. Feature-gated: solo compila lo necesario
6. Flexibilidad de storage: el VectorStore trait permite elegir backend por proyecto

### Negativas

1. **Dimension lock-in**: Cambiar de provider requiere re-indexar
2. **Cache invalidation**: Contenido actualizado puede servir embeddings stale
3. **Model download**: fastembed descarga ~100MB en el primer uso
4. **Storage lock-in por proyecto**: Kogral atado a SurrealDB, otros a LanceDB

### Mitigaciones

| Negativo          | Mitigación                                              |
| ----------------- | ------------------------------------------------------- |
| Dimension lock-in | Documentar claramente, warn en cambio de provider       |
| Cache stale       | TTL configurable, opción de bypass                      |
| Model download    | Mostrar progreso, cache en ~/.cache/fastembed           |
| Storage lock-in   | Decisión consciente basada en prioridades del proyecto  |

## Métricas de Éxito

| Métrica                     | Actual | Objetivo |
| --------------------------- | ------ | -------- |
| Implementaciones duplicadas | 3      | 1        |
| Cache hit rate              | 0%     | >60%     |
| Fallback availability       | 0%     | 100%     |
| Cost per 10K embeddings     | ~$0.20 | ~$0.05   |

## Guía de Selección de Provider

### Desarrollo

```rust
// Local, gratis, offline
let service = EmbeddingService::builder()
    .with_provider(FastEmbedProvider::small()?) // 384 dims
    .with_memory_cache()
    .build()?;
```

### Producción (Calidad)

```rust
// OpenAI con fallback local
let service = EmbeddingService::builder()
    .with_provider(OpenAiEmbeddingProvider::large()?) // 3072 dims
    .with_provider(FastEmbedProvider::large()?)       // Fallback
    .with_memory_cache()
    .build()?;
```

### Producción (Costo-Optimizado)

```rust
// OpenAI small con fallback
let service = EmbeddingService::builder()
    .with_provider(OpenAiEmbeddingProvider::small()?) // 1536 dims
    .with_provider(OllamaEmbeddingProvider::nomic())  // Fallback
    .with_memory_cache()
    .build()?;
```

## Matriz de Compatibilidad de Dimensiones

| Si usas...             | Puedes cambiar a...         | NO puedes cambiar a... |
| ---------------------- | --------------------------- | ---------------------- |
| fastembed small (384)  | fastembed small, all-minilm | Cualquier otro         |
| fastembed large (1024) | fastembed large             | Cualquier otro         |
| OpenAI small (1536)    | OpenAI small, ada-002       | Cualquier otro         |
| OpenAI large (3072)    | OpenAI large                | Cualquier otro         |

**Regla**: Solo puedes cambiar entre modelos con las MISMAS dimensiones.
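
Un guard ilustrativo (reutilizando el trait bocetado antes) convierte el mismatch silencioso en un error explícito:

```rust
/// Rechaza emparejar un provider con un índice construido a otra dimensionalidad.
fn check_dimensions(
    index_dims: usize,
    provider: &dyn EmbeddingProvider,
) -> Result<(), EmbeddingError> {
    if provider.dimensions() != index_dims {
        return Err(EmbeddingError(format!(
            "dimension mismatch: el índice tiene {} dims, el provider produce {}",
            index_dims,
            provider.dimensions()
        )));
    }
    Ok(())
}
```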
## Prioridad de Implementación

| Orden | Feature                 | Razón                        |
| ----- | ----------------------- | ---------------------------- |
| 1     | EmbeddingProvider trait | Base para todo               |
| 2     | FastEmbed provider      | Funciona sin API keys        |
| 3     | Memory cache            | Mayor impacto en costos      |
| 4     | VectorStore trait       | Abstracción de storage       |
| 5     | SurrealDbStore          | Kogral necesita graph+vector |
| 6     | LanceDbStore            | Provisioning/Vapora escala   |
| 7     | OpenAI provider         | Producción                   |
| 8     | Ollama provider         | Fallback local               |
| 9     | Batch processing        | Eficiencia                   |
| 10    | Metrics                 | Observabilidad               |

## Referencias

**Implementaciones Existentes**:

- Kogral: `kogral-core/src/embeddings/`
- Vapora: `vapora-llm-router/src/embeddings.rs`
- Provisioning: `provisioning/platform/crates/rag/src/embeddings.rs`

**Ubicación Objetivo**: `stratumiops/crates/stratum-embeddings/`
279
docs/es/architecture/adrs/002-stratum-llm.md
Normal file
@ -0,0 +1,279 @@
# ADR-002: Stratum-LLM - Biblioteca Unificada de Providers LLM

## Estado

**Propuesto**

## Contexto

### Estado Actual: Conexiones LLM Fragmentadas

El ecosistema stratumiops tiene 4 proyectos con funcionalidad IA, cada uno con su propia implementación:

| Proyecto     | Implementación             | Providers              | Duplicación           |
| ------------ | -------------------------- | ---------------------- | --------------------- |
| Vapora       | `typedialog-ai` (path dep) | Claude, OpenAI, Ollama | Base compartida       |
| TypeDialog   | `typedialog-ai` (local)    | Claude, OpenAI, Ollama | Define la abstracción |
| Provisioning | Custom `LlmClient`         | Claude, OpenAI         | 100% duplicado        |
| Kogral       | `rig-core`                 | Solo embeddings        | Diferente stack       |

### Problemas Identificados

#### 1. Duplicación de Código

Provisioning reimplementa lo que TypeDialog ya tiene:

- reqwest HTTP client
- Headers: x-api-key, anthropic-version
- JSON body formatting
- Response parsing
- Error handling

**Impacto**: ~500 líneas duplicadas; los bugs arreglados en un lugar no se propagan.

#### 2. Solo API Keys, Sin CLI Detection

Ningún proyecto detecta credenciales de los CLIs oficiales:

```text
Claude CLI: ~/.config/claude/credentials.json
OpenAI CLI: ~/.config/openai/credentials.json
```

**Impacto**: Usuarios con Claude Pro/Max ($20-100/mes) pagan API tokens cuando podrían usar su suscripción.

#### 3. Sin Fallback Automático

Cuando un provider falla (rate limit, timeout), la request falla completamente:

```text
Actual:  Request → Claude API → Rate Limit → ERROR
Deseado: Request → Claude API → Rate Limit → OpenAI → Success
```

#### 4. Sin Circuit Breaker

Si la API de Claude está caída, cada request intenta conectar, falla y propaga el error:

```text
Request 1 → Claude → Timeout (30s) → Error
Request 2 → Claude → Timeout (30s) → Error
Request 3 → Claude → Timeout (30s) → Error
```

**Impacto**: Latencia acumulada, UX degradada.

#### 5. Sin Caching

Las requests idénticas van siempre a la API:

```text
"Explain this Rust error" → Claude → $0.003
"Explain this Rust error" → Claude → $0.003 (mismo resultado)
```

**Impacto**: Costos innecesarios, especialmente en desarrollo/testing.

#### 6. Kogral No Integrado

Kogral tiene guidelines y patterns que podrían enriquecer el contexto del LLM, pero no hay integración.

## Decisión

Crear `stratum-llm` como crate unificado que:

1. **Consolide** las implementaciones existentes de typedialog-ai y provisioning
2. **Detecte** credenciales CLI y suscripciones antes de usar API keys
3. **Implemente** fallback automático con circuit breaker
4. **Añada** caching de requests para reducir costos
5. **Integre** Kogral para enriquecer el contexto
6. **Sea usado** por todos los proyectos del ecosistema

### Arquitectura

```text
┌─────────────────────────────────────────────────────────┐
│                       stratum-llm                       │
├─────────────────────────────────────────────────────────┤
│  CredentialDetector                                     │
│  ├─ Claude CLI → ~/.config/claude/ (subscription)       │
│  ├─ OpenAI CLI → ~/.config/openai/                      │
│  ├─ Env vars → *_API_KEY                                │
│  └─ Ollama → localhost:11434 (free)                     │
│                     │                                   │
│                     ▼                                   │
│  ProviderChain (ordered by priority)                    │
│  [CLI/Sub] → [API] → [DeepSeek] → [Ollama]              │
│       │        │         │            │                 │
│       └────────┴─────────┴────────────┘                 │
│                     │                                   │
│         CircuitBreaker per provider                     │
│                     │                                   │
│               RequestCache                              │
│                     │                                   │
│             KogralIntegration                           │
│                     │                                   │
│               UnifiedClient                             │
│                                                         │
└─────────────────────────────────────────────────────────┘
```

## Justificación

### Por Qué No Usar Otra Crate Externa

| Alternativa    | Por Qué No                                 |
| -------------- | ------------------------------------------ |
| kaccy-ai       | Orientada a blockchain/fraud detection     |
| llm (crate)    | Muy básica, sin circuit breaker ni caching |
| langchain-rust | Port de Python, no idiomático Rust         |
| rig-core       | Solo embeddings/RAG, no chat completion    |

**Mejor opción**: Construir sobre typedialog-ai y añadir las features faltantes.

### Por Qué CLI Detection es Importante

Análisis de costos para un usuario típico:

| Escenario                 | Costo Mensual       |
| ------------------------- | ------------------- |
| Solo API (actual)         | ~$840               |
| Claude Pro + API overflow | ~$20 + ~$200 = $220 |
| Claude Max + API overflow | ~$100 + ~$50 = $150 |

**Ahorro potencial**: 70-80% detectando y usando suscripciones primero.

### Por Qué Circuit Breaker

Sin circuit breaker, un provider caído causa:

- N requests × 30s timeout = N×30s de latencia total
- Todos los recursos ocupados esperando timeouts

Con circuit breaker:

- La primera falla abre el circuito
- Las siguientes requests fallan inmediatamente (fast fail)
- Fallback a otro provider sin esperar
- El circuito se resetea después del cooldown (ver el boceto más abajo)
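
Boceto mínimo de ese ciclo open/half-open (las implementaciones reales suelen añadir un umbral de fallos; esto solo ilustra el comportamiento de fast-fail y cooldown):

```rust
use std::time::{Duration, Instant};

struct CircuitBreaker {
    opened_at: Option<Instant>,
    cooldown: Duration,
}

impl CircuitBreaker {
    fn new(cooldown: Duration) -> Self {
        Self { opened_at: None, cooldown }
    }

    /// Devuelve true si se puede intentar la request.
    fn allow(&mut self) -> bool {
        match self.opened_at {
            None => true,
            Some(opened) if opened.elapsed() >= self.cooldown => {
                self.opened_at = None; // half-open: permite una request de prueba
                true
            }
            Some(_) => false, // abierto: fast fail, la cadena prueba otro provider
        }
    }

    fn record_failure(&mut self) {
        self.opened_at = Some(Instant::now());
    }

    fn record_success(&mut self) {
        self.opened_at = None;
    }
}
```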
### Por Qué Caching
|
||||
|
||||
Para desarrollo típico:
|
||||
|
||||
- Mismas preguntas repetidas mientras se itera
|
||||
- Testing ejecuta mismos prompts múltiples veces
|
||||
|
||||
Cache hit rate estimado: 15-30% en desarrollo activo.
|
||||
|
||||
### Why Kogral Integration

Kogral has per-language guidelines, per-domain patterns, and ADRs.
Without integration the LLM generates generic code;
with integration it generates code that follows project conventions.
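A sketch of what the enrichment could look like; `KogralStore` and `guidelines_for` are hypothetical names standing in for the real Kogral lookup:

```rust
/// Hypothetical handle to a local Kogral knowledge base; the real
/// integration would query Kogral's store instead of this stub.
struct KogralStore;

impl KogralStore {
    async fn guidelines_for(&self, language: &str) -> Option<String> {
        // Stub: a real implementation would look up kogral/ content.
        match language {
            "rust" => Some("Use thiserror for library errors; avoid unwrap().".into()),
            _ => None,
        }
    }
}

/// Prepend project conventions to the user prompt before sending it to
/// the LLM, so generated code follows Kogral guidelines.
async fn enrich_prompt(kogral: &KogralStore, language: &str, user_prompt: &str) -> String {
    match kogral.guidelines_for(language).await {
        Some(g) => format!("Follow these project conventions:\n{g}\n\n---\n{user_prompt}"),
        None => user_prompt.to_string(),
    }
}
```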
## Consequences

### Positive

1. Single source of truth for LLM logic
2. CLI detection cuts costs 70-80%
3. Circuit breaker + fallback = high availability
4. 15-30% fewer requests during development (caching)
5. Kogral improves generation quality
6. Feature-gated: every feature is optional
### Negative

1. **Migration effort**: Refactor Vapora, TypeDialog, Provisioning
2. **New dependency**: Projects come to depend on stratumiops
3. **CLI auth complexity**: Credential formats differ across CLI versions
4. **Cache invalidation**: Stale responses if not managed carefully
### Mitigations

| Negative            | Mitigation                                      |
| ------------------- | ----------------------------------------------- |
| Migration effort    | Compatible API re-exported from typedialog-ai   |
| New dependency      | Local path dependency, not crates.io            |
| CLI auth complexity | Version detection, fall back to API on failure  |
| Cache invalidation  | Configurable TTL, bypass option                 |
## Success Metrics

| Metric                   | Current | Target          |
| ------------------------ | ------- | --------------- |
| Duplicated lines of code | ~500    | 0               |
| CLI credential detection | 0%      | 100%            |
| Fallback success rate    | 0%      | >90%            |
| Cache hit rate           | 0%      | 15-30%          |
| Latency (provider down)  | 30s+    | <1s (fast fail) |
## Cost Impact Analysis

Based on real usage data ($840/month):

| Scenario                   | Savings            |
| -------------------------- | ------------------ |
| CLI detection (Claude Max) | ~$700/month        |
| Caching (15% hit rate)     | ~$50/month         |
| DeepSeek fallback for code | ~$100/month        |
| **Total potential**        | **$500-700/month** |
## Migration Strategy

### Migration Phases

1. Create stratum-llm with an API compatible with typedialog-ai
2. typedialog-ai re-exports stratum-llm (backward compatible; see the sketch after this list)
3. Vapora migrates to stratum-llm directly
4. Provisioning migrates its LlmClient to stratum-llm
5. Deprecate typedialog-ai, consolidate into stratum-llm
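A sketch of the phase-2 re-export, reusing the hypothetical `UnifiedClient` name from earlier; the `AiClient` alias is likewise illustrative:

```rust
// typedialog-ai/src/lib.rs, phase 2 (sketch): the crate becomes a thin
// facade over stratum-llm so existing callers keep compiling unchanged.
// `UnifiedClient` and `AiClient` are hypothetical names.

// Re-export the whole public API under the old crate name.
pub use stratum_llm::*;

// Keep renamed items reachable at their old paths while nudging callers.
#[deprecated(note = "use stratum_llm::UnifiedClient instead")]
pub type AiClient = stratum_llm::UnifiedClient;
```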
### Feature Adoption

| Feature         | Adoption                                |
| --------------- | --------------------------------------- |
| Basic providers | Immediate (drop-in replacement)         |
| CLI detection   | Optional, feature flag                  |
| Circuit breaker | Default on                              |
| Caching         | Default on, configurable TTL            |
| Kogral          | Feature flag, requires Kogral installed |
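A sketch of how these optional features could be gated in code; the feature name `cli-detect` and the builder shape are assumptions:

```rust
/// Sketch of feature gating in the crate root; the "cli-detect"
/// feature name and this builder are illustrative, not final.
pub struct ClientBuilder {
    use_cache: bool, // caching defaults to on, per the table above
}

impl ClientBuilder {
    pub fn new() -> Self {
        Self { use_cache: true }
    }

    /// Cache bypass for callers that need fresh responses
    /// (the "cache invalidation" mitigation).
    pub fn without_cache(mut self) -> Self {
        self.use_cache = false;
        self
    }

    /// Compiled in only when the optional "cli-detect" feature is on;
    /// without it, the chain starts directly at API keys.
    #[cfg(feature = "cli-detect")]
    pub fn detect_credentials(self) -> Self {
        // ...probe subscription CLIs, then env vars, then Ollama...
        self
    }
}
```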
## Alternatives Considered

### Alternative 1: Improve typedialog-ai In-Place

**Pros**: No new crate required

**Cons**: TypeDialog is a specific project, not shared infrastructure

**Decision**: stratum-llm in stratumiops is the better home for cross-project infrastructure.

### Alternative 2: Use LiteLLM (Python) as a Proxy

**Pros**: Very complete, 100+ providers

**Cons**: Python dependency, proxy latency, not Rust-native

**Decision**: Keep the stack pure Rust.

### Alternative 3: Each Project Keeps Its Own Implementation

**Pros**: Independence

**Cons**: Duplication, inconsistency, unshared bug fixes

**Decision**: Consolidation is better long-term.
## References

**Existing Implementations**:

- TypeDialog: `typedialog/crates/typedialog-ai/`
- Vapora: `vapora/crates/vapora-llm-router/`
- Provisioning: `provisioning/platform/crates/rag/`

**Kogral**: `kogral/`

**Target Location**: `stratumiops/crates/stratum-llm/`
22
docs/es/architecture/adrs/README.md
Normal file
@ -0,0 +1,22 @@
# ADRs - Architecture Decision Records

Record of architectural decisions for the STRATUMIOPS ecosystem.

## Active ADRs

| ID                               | Title                                         | Status   |
| -------------------------------- | --------------------------------------------- | -------- |
| [001](001-stratum-embeddings.md) | Stratum-Embeddings: Unified Embedding Library | Proposed |
| [002](002-stratum-llm.md)        | Stratum-LLM: Unified LLM Provider Library     | Proposed |

## Statuses

- **Proposed**: Under review, pending implementation
- **Accepted**: Approved and implemented
- **Deprecated**: Replaced by another ADR
- **Superseded**: Obsolete, see the replacement ADR

## Navigation

- [Back to architecture](../)
- [English version](../../../en/architecture/adrs/)