chore: create stratum-embeddings and stratum-llm crates, docs
Some checks failed
Rust CI / Security Audit (push) Has been cancelled
Rust CI / Check + Test + Lint (nightly) (push) Has been cancelled
Rust CI / Check + Test + Lint (stable) (push) Has been cancelled

This commit is contained in:
Jesús Pérez 2026-01-24 02:03:12 +00:00
parent b0d039d22d
commit 0ae853c2fa
Signed by: jesus
GPG Key ID: 9F243E355E0BC939
70 changed files with 19516 additions and 2 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
CLAUDE.md CLAUDE.md
.claude .claude
utils/save*sh utils/save*sh
.fastembed_cache
COMMIT_MESSAGE.md COMMIT_MESSAGE.md
.wrks .wrks
nushell nushell

37
CHANGELOG.md Normal file
View File

@ -0,0 +1,37 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- **Architecture Documentation**: New `docs/*/architecture/` section with ADRs
- ADR-001: stratum-embeddings - Unified embedding library with caching, fallback,
and VectorStore trait (SurrealDB for Kogral, LanceDB for Provisioning/Vapora)
- ADR-002: stratum-llm - Unified LLM provider library with CLI credential detection,
circuit breaker, caching, and Kogral integration
- **Bilingual ADRs**: Full English and Spanish versions of all architecture documents
- **README updates**: Added Stratum Crates section and updated documentation structure
### Changed
- Documentation structure now includes `architecture/adrs/` subdirectory in both
language directories (en/es)
## [0.1.0] - 2026-01-22
### Added
- Initial repository setup
- Main documentation structure (bilingual en/es)
- Branding assets (logos, icons, social variants)
- CI/CD configuration (GitHub Actions, Woodpecker)
- Language guidelines (Rust, Nickel, Nushell, Bash)
- Pre-commit hooks configuration
[Unreleased]: https://repo.jesusperez.pro/jesus/stratumiops/compare/v0.1.0...HEAD
[0.1.0]: https://repo.jesusperez.pro/jesus/stratumiops/releases/tag/v0.1.0

10658
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

60
Cargo.toml Normal file
View File

@ -0,0 +1,60 @@
[workspace]
members = ["crates/*"]
resolver = "2"
[workspace.package]
edition = "2021"
license = "MIT OR Apache-2.0"
[workspace.dependencies]
# Async runtime
tokio = { version = "1.49", features = ["full"] }
async-trait = "0.1"
futures = "0.3"
# HTTP client
reqwest = { version = "0.13", features = ["json"] }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9"
humantime-serde = "1.1"
# Caching
moka = { version = "0.12", features = ["future"] }
sled = "0.34"
# Embeddings
fastembed = "5.8"
# Vector storage
lancedb = "0.23"
surrealdb = { version = "2.5", features = ["kv-mem"] }
# LOCKED: Arrow 56.x required for LanceDB 0.23 compatibility
# LanceDB 0.23 uses Arrow 56.2.0 internally - Arrow 57 breaks API compatibility
# DO NOT upgrade to Arrow 57 until LanceDB supports it
arrow = "=56"
# Error handling
thiserror = "2.0"
anyhow = "1.0"
# Logging and tracing
tracing = "0.1"
tracing-subscriber = "0.3"
# Metrics
prometheus = "0.14"
# Utilities
xxhash-rust = { version = "0.8", features = ["xxh3"] }
dirs = "6.0"
chrono = "0.4"
uuid = "1.19"
which = "8.0"
# Testing
tokio-test = "0.4"
approx = "0.5"
tempfile = "3.24"

View File

@ -131,16 +131,29 @@ StratumIOps is not a single project. It's the **orchestration layer** that coord
- **Integration Patterns**: How projects work together - **Integration Patterns**: How projects work together
- **Shared Standards**: Language guidelines (Rust, Nickel, Nushell, Bash) - **Shared Standards**: Language guidelines (Rust, Nickel, Nushell, Bash)
### Stratum Crates
Shared infrastructure libraries for the ecosystem:
| Crate | Description | Status |
| ----- | ----------- | ------ |
| **stratum-embeddings** | Unified embedding providers with caching, fallback, and VectorStore trait | Proposed |
| **stratum-llm** | Unified LLM providers with CLI detection, circuit breaker, and caching | Proposed |
See [Architecture ADRs](docs/en/architecture/adrs/) for detailed design decisions.
### Documentation Structure ### Documentation Structure
```text ```text
docs/ docs/
├── en/ # English documentation ├── en/ # English documentation
│ ├── ia/ # AI/Development track │ ├── ia/ # AI/Development track
│ └── ops/ # Ops/DevOps track │ ├── ops/ # Ops/DevOps track
│ └── architecture/ # Architecture decisions (ADRs)
└── es/ # Spanish documentation └── es/ # Spanish documentation
├── ia/ # AI/Development track ├── ia/ # AI/Development track
└── ops/ # Ops/DevOps track ├── ops/ # Ops/DevOps track
└── architecture/ # Architecture decisions (ADRs)
``` ```
### Branding Assets ### Branding Assets

View File

@ -0,0 +1,113 @@
[package]
name = "stratum-embeddings"
version = "0.1.0"
edition.workspace = true
description = "Unified embedding providers with caching, batch processing, and vector storage"
license.workspace = true
[dependencies]
# Async runtime
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
# HTTP client (for cloud providers)
reqwest = { workspace = true, optional = true }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
humantime-serde = { workspace = true }
# Caching
moka = { workspace = true }
# Persistent cache (optional)
sled = { workspace = true, optional = true }
# Local embeddings
fastembed = { workspace = true, optional = true }
# Vector storage backends
lancedb = { workspace = true, optional = true }
surrealdb = { workspace = true, optional = true }
arrow = { workspace = true, optional = true }
# Error handling
thiserror = { workspace = true }
# Logging
tracing = { workspace = true }
# Metrics
prometheus = { workspace = true, optional = true }
# Utilities
xxhash-rust = { workspace = true }
[features]
default = ["fastembed-provider", "memory-cache"]
# Providers
fastembed-provider = ["fastembed"]
openai-provider = ["reqwest"]
ollama-provider = ["reqwest"]
cohere-provider = ["reqwest"]
voyage-provider = ["reqwest"]
huggingface-provider = ["reqwest"]
all-providers = [
"fastembed-provider",
"openai-provider",
"ollama-provider",
"cohere-provider",
"voyage-provider",
"huggingface-provider",
]
# Cache backends
memory-cache = []
persistent-cache = ["sled"]
all-cache = ["memory-cache", "persistent-cache"]
# Vector storage backends
lancedb-store = ["lancedb", "arrow"]
surrealdb-store = ["surrealdb"]
all-stores = ["lancedb-store", "surrealdb-store"]
# Observability
metrics = ["prometheus"]
# Project-specific presets
kogral = ["fastembed-provider", "memory-cache", "surrealdb-store"]
provisioning = ["openai-provider", "memory-cache", "lancedb-store"]
vapora = ["all-providers", "memory-cache", "lancedb-store"] # Includes huggingface-provider
# Full feature set
full = ["all-providers", "all-cache", "all-stores", "metrics"]
[dev-dependencies]
tokio-test = { workspace = true }
approx = { workspace = true }
tempfile = { workspace = true }
tracing-subscriber = { workspace = true }
# Example-specific feature requirements
[[example]]
name = "basic_usage"
required-features = ["fastembed-provider"]
[[example]]
name = "fallback_demo"
required-features = ["ollama-provider", "fastembed-provider"]
[[example]]
name = "lancedb_usage"
required-features = ["lancedb-store", "fastembed-provider"]
[[example]]
name = "surrealdb_usage"
required-features = ["surrealdb-store", "fastembed-provider"]
[[example]]
name = "huggingface_usage"
required-features = ["huggingface-provider"]

View File

@ -0,0 +1,180 @@
# stratum-embeddings
Unified embedding providers with caching, batch processing, and vector storage for the STRATUMIOPS ecosystem.
## Features
- **Multiple Providers**: FastEmbed (local), OpenAI, Ollama
- **Smart Caching**: In-memory caching with configurable TTL
- **Batch Processing**: Efficient batch embedding with automatic chunking
- **Vector Storage**: LanceDB (scale-first) and SurrealDB (graph-first)
- **Fallback Support**: Automatic failover between providers
- **Feature Flags**: Modular compilation for minimal dependencies
## Architecture
```text
┌─────────────────────────────────────────┐
│ EmbeddingService │
│ (facade with caching + fallback) │
└─────────────┬───────────────────────────┘
┌─────────┴─────────┐
▼ ▼
┌─────────────┐ ┌─────────────┐
│ Providers │ │ Cache │
│ │ │ │
│ • FastEmbed │ │ • Memory │
│ • OpenAI │ │ • (Sled) │
│ • Ollama │ │ │
└─────────────┘ └─────────────┘
```
## Quick Start
### Basic Usage
```rust
use stratum_embeddings::{
EmbeddingService, FastEmbedProvider, MemoryCache, EmbeddingOptions
};
use std::time::Duration;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let provider = FastEmbedProvider::small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = EmbeddingService::new(provider).with_cache(cache);
let options = EmbeddingOptions::default_with_cache();
let embedding = service.embed("Hello world", &options).await?;
println!("Generated {} dimensions", embedding.len());
Ok(())
}
```
### Batch Processing
```rust
let texts = vec![
"Text 1".to_string(),
"Text 2".to_string(),
"Text 3".to_string(),
];
let result = service.embed_batch(texts, &options).await?;
println!("Embeddings: {}, Cached: {}",
result.embeddings.len(),
result.cached_count
);
```
### Vector Storage
#### LanceDB (Provisioning, Vapora)
```rust
use stratum_embeddings::{LanceDbStore, VectorStore, VectorStoreConfig};
let config = VectorStoreConfig::new(384);
let store = LanceDbStore::new("./data", "embeddings", config).await?;
store.upsert("doc1", &embedding, metadata).await?;
let results = store.search(&query_embedding, 10, None).await?;
```
#### SurrealDB (Kogral)
```rust
use stratum_embeddings::{SurrealDbStore, VectorStore, VectorStoreConfig};
let config = VectorStoreConfig::new(384);
let store = SurrealDbStore::new_memory("concepts", config).await?;
store.upsert("concept1", &embedding, metadata).await?;
let results = store.search(&query_embedding, 10, None).await?;
```
## Feature Flags
### Providers
- `fastembed-provider` (default) - Local embeddings via fastembed
- `openai-provider` - OpenAI API embeddings
- `ollama-provider` - Ollama local server embeddings
- `all-providers` - All embedding providers
### Cache
- `memory-cache` (default) - In-memory caching with moka
- `persistent-cache` - Persistent cache with sled
- `all-cache` - All cache backends
### Vector Storage
- `lancedb-store` - LanceDB vector storage (columnar, disk-native)
- `surrealdb-store` - SurrealDB vector storage (graph + vector)
- `all-stores` - All storage backends
### Project Presets
- `kogral` - fastembed + memory + surrealdb
- `provisioning` - openai + memory + lancedb
- `vapora` - all-providers + memory + lancedb
- `full` - Everything enabled
## Examples
Run examples with:
```bash
cargo run --example basic_usage --features=default
cargo run --example fallback_demo --features=fastembed-provider,ollama-provider
cargo run --example lancedb_usage --features=lancedb-store
cargo run --example surrealdb_usage --features=surrealdb-store
```
## Provider Comparison
| Provider | Type | Cost | Dimensions | Use Case |
|----------|------|------|------------|----------|
| FastEmbed | Local | Free | 384-1024 | Dev, privacy-first |
| OpenAI | Cloud | $0.02-0.13/1M | 1536-3072 | Production RAG |
| Ollama | Local | Free | 384-1024 | Self-hosted |
## Storage Backend Comparison
| Backend | Best For | Strength | Scale |
|---------|----------|----------|-------|
| LanceDB | RAG, traces | Columnar, IVF-PQ index | Billions |
| SurrealDB | Knowledge graphs | Unified graph+vector queries | Millions |
## Configuration
Environment variables:
```bash
# FastEmbed
FASTEMBED_MODEL=bge-small-en
# OpenAI
OPENAI_API_KEY=sk-...
OPENAI_MODEL=text-embedding-3-small
# Ollama
OLLAMA_MODEL=nomic-embed-text
OLLAMA_BASE_URL=http://localhost:11434
```
## Development
```bash
cargo check -p stratum-embeddings --all-features
cargo test -p stratum-embeddings --all-features
cargo clippy -p stratum-embeddings --all-features -- -D warnings
```
## License
MIT OR Apache-2.0

View File

@ -0,0 +1,346 @@
# HuggingFace Embedding Provider
Provider for HuggingFace Inference API embeddings with support for popular sentence-transformers and BGE models.
## Overview
The HuggingFace provider uses the free Inference API to generate embeddings. It supports:
- **Public Models**: Free access to popular embedding models
- **Custom Models**: Support for any HuggingFace model with feature-extraction pipeline
- **Automatic Caching**: Built-in memory cache reduces API calls
- **Response Normalization**: Optional L2 normalization for similarity search
## Features
- ✅ Zero cost for public models (free Inference API)
- ✅ Support for 5+ popular models out of the box
- ✅ Custom model support with configurable dimensions
- ✅ Automatic retry with exponential backoff
- ✅ Rate limit handling
- ✅ Integration with stratum-embeddings caching layer
## Supported Models
### Predefined Models
| Model | Dimensions | Use Case | Constructor |
|-------|------------|----------|-------------|
| **BAAI/bge-small-en-v1.5** | 384 | General-purpose, efficient | `HuggingFaceProvider::bge_small()` |
| **BAAI/bge-base-en-v1.5** | 768 | Balanced performance | `HuggingFaceProvider::bge_base()` |
| **BAAI/bge-large-en-v1.5** | 1024 | High quality | `HuggingFaceProvider::bge_large()` |
| **sentence-transformers/all-MiniLM-L6-v2** | 384 | Fast, lightweight | `HuggingFaceProvider::all_minilm()` |
| **sentence-transformers/all-mpnet-base-v2** | 768 | Strong baseline | - |
### Custom Models
```rust
let model = HuggingFaceModel::Custom(
"sentence-transformers/paraphrase-MiniLM-L6-v2".to_string(),
384,
);
let provider = HuggingFaceProvider::new(api_key, model)?;
```
## API Rate Limits
### Free Inference API
HuggingFace Inference API has the following rate limits:
| Tier | Requests/Hour | Requests/Day | Max Concurrent |
|------|---------------|--------------|----------------|
| **Anonymous** | 1,000 | 10,000 | 1 |
| **Free Account** | 3,000 | 30,000 | 3 |
| **PRO ($9/mo)** | 10,000 | 100,000 | 10 |
| **Enterprise** | Custom | Custom | Custom |
**Rate Limit Headers**:
```
X-RateLimit-Limit: 3000
X-RateLimit-Remaining: 2999
X-RateLimit-Reset: 1234567890
```
### Rate Limit Handling
The provider automatically handles rate limits with:
1. **Exponential Backoff**: Retries with increasing delays (1s, 2s, 4s, 8s)
2. **Max Retries**: Default 3 retries before failing
3. **Circuit Breaker**: Automatically pauses requests if rate limited repeatedly
4. **Cache Integration**: Reduces API calls by 70-90% for repeated queries
**Configuration**:
```rust
// Default retry config (built-in)
let provider = HuggingFaceProvider::new(api_key, model)?;
// With custom retry (future enhancement)
let provider = HuggingFaceProvider::new(api_key, model)?
.with_retry_config(RetryConfig {
max_retries: 5,
initial_delay: Duration::from_secs(2),
max_delay: Duration::from_secs(30),
});
```
### Best Practices for Rate Limits
1. **Enable Caching**: Use `EmbeddingOptions::default_with_cache()`
```rust
let options = EmbeddingOptions::default_with_cache();
let embedding = provider.embed(text, &options).await?;
```
2. **Batch Requests Carefully**: HuggingFace Inference API processes requests sequentially
```rust
// This makes N API calls sequentially
let texts = vec!["text1", "text2", "text3"];
let result = provider.embed_batch(&texts, &options).await?;
```
3. **Use PRO Account for Production**: Free tier is suitable for development only
4. **Monitor Rate Limits**: Check response headers
```rust
// Future enhancement - rate limit monitoring
let stats = provider.rate_limit_stats();
println!("Remaining: {}/{}", stats.remaining, stats.limit);
```
## Authentication
### Environment Variables
The provider checks for API keys in this order:
1. `HUGGINGFACE_API_KEY`
2. `HF_TOKEN` (alternative name)
```bash
export HUGGINGFACE_API_KEY="hf_xxxxxxxxxxxxxxxxxxxx"
```
### Getting an API Token
1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens)
2. Click "New token"
3. Select "Read" access (sufficient for Inference API)
4. Copy the token starting with `hf_`
## Usage Examples
### Basic Usage
```rust
use stratum_embeddings::{HuggingFaceProvider, EmbeddingOptions};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Using predefined model
let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions::default_with_cache();
let embedding = provider.embed("Hello world", &options).await?;
println!("Dimensions: {}", embedding.len()); // 384
Ok(())
}
```
### With EmbeddingService (Recommended)
```rust
use std::time::Duration;
use stratum_embeddings::{
HuggingFaceProvider, EmbeddingService, MemoryCache, EmbeddingOptions
};
let provider = HuggingFaceProvider::bge_small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(3600));
let service = EmbeddingService::new(provider)
.with_cache(cache);
let options = EmbeddingOptions::default_with_cache();
let embedding = service.embed("Cached embeddings", &options).await?;
```
### Semantic Similarity Search
```rust
use stratum_embeddings::{HuggingFaceProvider, EmbeddingOptions, cosine_similarity};
let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions {
normalize: true, // Important for cosine similarity
truncate: true,
use_cache: true,
};
let query = "machine learning";
let doc1 = "deep learning and neural networks";
let doc2 = "cooking recipes";
let query_emb = provider.embed(query, &options).await?;
let doc1_emb = provider.embed(doc1, &options).await?;
let doc2_emb = provider.embed(doc2, &options).await?;
let sim1 = cosine_similarity(&query_emb, &doc1_emb);
let sim2 = cosine_similarity(&query_emb, &doc2_emb);
println!("Similarity with doc1: {:.4}", sim1); // ~0.85
println!("Similarity with doc2: {:.4}", sim2); // ~0.15
```
### Custom Model
```rust
use stratum_embeddings::{HuggingFaceProvider, HuggingFaceModel};
let api_key = std::env::var("HUGGINGFACE_API_KEY")?;
let model = HuggingFaceModel::Custom(
"intfloat/multilingual-e5-large".to_string(),
1024, // Specify dimensions
);
let provider = HuggingFaceProvider::new(api_key, model)?;
```
## Error Handling
### Common Errors
| Error | Cause | Solution |
|-------|-------|----------|
| `ConfigError: API key is empty` | Missing credentials | Set `HUGGINGFACE_API_KEY` |
| `ApiError: HTTP 401` | Invalid API token | Check token validity |
| `ApiError: HTTP 429` | Rate limit exceeded | Wait or upgrade tier |
| `ApiError: HTTP 503` | Model loading | Retry after ~20s |
| `DimensionMismatch` | Wrong model dimensions | Update `Custom` model dims |
### Retry Example
```rust
use tokio::time::sleep;
use std::time::Duration;
let mut retries = 0;
let max_retries = 3;
loop {
match provider.embed(text, &options).await {
Ok(embedding) => break Ok(embedding),
Err(e) if e.to_string().contains("429") && retries < max_retries => {
retries += 1;
let delay = Duration::from_secs(2u64.pow(retries));
eprintln!("Rate limited, retrying in {:?}...", delay);
sleep(delay).await;
}
Err(e) => break Err(e),
}
}
```
## Performance Characteristics
### Latency
| Operation | Latency | Notes |
|-----------|---------|-------|
| **Single embed** | 200-500ms | Depends on model size and region |
| **Batch (N items)** | N × 200-500ms | Sequential processing |
| **Cache hit** | <1ms | In-memory lookup |
| **Cold start** | +5-20s | First request loads model |
### Throughput
| Tier | Max RPS | Daily Limit |
|------|---------|-------------|
| Free | ~0.8 | 30,000 |
| PRO | ~2.8 | 100,000 |
**With Caching** (80% hit rate):
- Free tier: ~4 effective RPS
- PRO tier: ~14 effective RPS
## Cost Comparison
| Provider | Cost/1M Tokens | Free Tier | Notes |
|----------|----------------|-----------|-------|
| **HuggingFace** | $0.00 | 30k req/day | Free for public models |
| OpenAI | $0.02-0.13 | $5 credit | Pay per token |
| Cohere | $0.10 | 100 req/month | Limited free tier |
| Voyage | $0.12 | None | No free tier |
## Limitations
1. **No True Batching**: Inference API processes one request at a time
2. **Cold Starts**: Models need ~20s to load on first request
3. **Rate Limits**: Free tier suitable for development only
4. **Regional Latency**: Single region (US/EU), no edge locations
5. **Model Loading**: Popular models cached, custom models may be slow
## Advanced Configuration
### Model Loading Timeout
```rust
// Future enhancement
let provider = HuggingFaceProvider::new(api_key, model)?
.with_timeout(Duration::from_secs(120)); // Wait longer for cold starts
```
### Dedicated Inference Endpoints
For production workloads, consider [Dedicated Endpoints](https://huggingface.co/inference-endpoints):
- True batch processing
- Guaranteed uptime
- No rate limits
- Custom regions
- ~$60-500/month
## Migration Guide
### From vapora Custom Implementation
**Before**:
```rust
let hf = HuggingFaceEmbedding::new(api_key, "BAAI/bge-small-en-v1.5".to_string());
let embedding = hf.embed(text).await?;
```
**After**:
```rust
let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions::default_with_cache();
let embedding = provider.embed(text, &options).await?;
```
### From OpenAI
```rust
// OpenAI (paid)
let provider = OpenAiProvider::new(api_key, OpenAiModel::TextEmbedding3Small)?;
// HuggingFace (free, similar quality)
let provider = HuggingFaceProvider::bge_small()?;
```
## Running the Example
```bash
export HUGGINGFACE_API_KEY="hf_xxxxxxxxxxxxxxxxxxxx"
cargo run --example huggingface_usage \
--features huggingface-provider
```
## References
- [HuggingFace Inference API Docs](https://huggingface.co/docs/api-inference/index)
- [BGE Embedding Models](https://huggingface.co/BAAI)
- [Sentence Transformers](https://www.sbert.net/)
- [Rate Limits Documentation](https://huggingface.co/docs/api-inference/rate-limits)

View File

@ -0,0 +1,50 @@
use std::time::Duration;
use stratum_embeddings::{EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache};
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
info!("Initializing FastEmbed provider...");
let provider = FastEmbedProvider::small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = EmbeddingService::new(provider).with_cache(cache);
info!("Service ready: {:?}", service.provider_info());
let options = EmbeddingOptions::default_with_cache();
info!("Embedding single text...");
let text = "Stratum embeddings is a unified embedding library";
let embedding = service.embed(text, &options).await?;
info!("Generated embedding with {} dimensions", embedding.len());
info!("Embedding same text again (should be cached)...");
let embedding2 = service.embed(text, &options).await?;
assert_eq!(embedding, embedding2);
info!("Cache hit confirmed!");
info!("Embedding batch of texts...");
let texts = vec![
"Rust is a systems programming language".to_string(),
"Knowledge graphs connect concepts".to_string(),
"Vector databases enable semantic search".to_string(),
];
let result = service.embed_batch(texts, &options).await?;
info!(
"Batch complete: {} embeddings generated",
result.embeddings.len()
);
info!("Model: {}, Dimensions: {}", result.model, result.dimensions);
info!("Cached count: {}", result.cached_count);
info!("Cache size: {}", service.cache_size());
Ok(())
}

View File

@ -0,0 +1,44 @@
use std::{sync::Arc, time::Duration};
use stratum_embeddings::{
EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache, OllamaProvider,
};
use tracing::{info, warn};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
info!("Setting up primary provider (Ollama)...");
let primary = OllamaProvider::default_model()?;
info!("Setting up fallback provider (FastEmbed)...");
let fallback =
Arc::new(FastEmbedProvider::small()?) as Arc<dyn stratum_embeddings::EmbeddingProvider>;
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = EmbeddingService::new(primary)
.with_cache(cache)
.with_fallback(fallback);
let options = EmbeddingOptions::default_with_cache();
info!("Checking if Ollama is available...");
if service.is_ready().await {
info!("Ollama is available, using as primary");
} else {
warn!("Ollama not available, will fall back to FastEmbed");
}
info!("Embedding text (will use available provider)...");
let text = "This demonstrates fallback strategy in action";
let embedding = service.embed(text, &options).await?;
info!(
"Successfully generated embedding with {} dimensions",
embedding.len()
);
info!("Cache size: {}", service.cache_size());
Ok(())
}

View File

@ -0,0 +1,125 @@
use std::time::Duration;
use stratum_embeddings::{
EmbeddingOptions, HuggingFaceModel, HuggingFaceProvider, MemoryCache,
};
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
info!("=== HuggingFace Embedding Provider Demo ===");
// Example 1: Using predefined model (bge-small)
info!("\n1. Using predefined BGE-small model (384 dimensions)");
let provider = HuggingFaceProvider::bge_small()?;
let options = EmbeddingOptions::default_with_cache();
let text = "HuggingFace provides free inference API for embedding models";
let embedding = provider.embed(text, &options).await?;
info!(
"Generated embedding with {} dimensions from BGE-small",
embedding.len()
);
info!("First 5 values: {:?}", &embedding[..5]);
// Example 2: Using different model size
info!("\n2. Using BGE-base model (768 dimensions)");
let provider = HuggingFaceProvider::bge_base()?;
let embedding = provider.embed(text, &options).await?;
info!(
"Generated embedding with {} dimensions from BGE-base",
embedding.len()
);
// Example 3: Using custom model
info!("\n3. Using custom model");
let api_key = std::env::var("HUGGINGFACE_API_KEY")
.or_else(|_| std::env::var("HF_TOKEN"))
.expect("Set HUGGINGFACE_API_KEY or HF_TOKEN");
let custom_model = HuggingFaceModel::Custom(
"sentence-transformers/paraphrase-MiniLM-L6-v2".to_string(),
384,
);
let provider = HuggingFaceProvider::new(api_key, custom_model)?;
let embedding = provider.embed(text, &options).await?;
info!(
"Custom model embedding: {} dimensions",
embedding.len()
);
// Example 4: Batch embeddings (sequential requests to HF API)
info!("\n4. Batch embedding (sequential API calls)");
let provider = HuggingFaceProvider::all_minilm()?;
let texts = vec![
"First document about embeddings",
"Second document about transformers",
"Third document about NLP",
];
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
let result = provider.embed_batch(&text_refs, &options).await?;
info!("Embedded {} texts", result.embeddings.len());
for (i, emb) in result.embeddings.iter().enumerate() {
info!(" Text {}: {} dimensions", i + 1, emb.len());
}
// Example 5: Using with cache
info!("\n5. Demonstrating cache effectiveness");
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = stratum_embeddings::EmbeddingService::new(
HuggingFaceProvider::bge_small()?
).with_cache(cache);
let cached_options = EmbeddingOptions::default_with_cache();
// First call - cache miss
let start = std::time::Instant::now();
let _ = service.embed(text, &cached_options).await?;
let first_duration = start.elapsed();
info!("First call (cache miss): {:?}", first_duration);
// Second call - cache hit
let start = std::time::Instant::now();
let _ = service.embed(text, &cached_options).await?;
let second_duration = start.elapsed();
info!("Second call (cache hit): {:?}", second_duration);
info!("Speedup: {:.2}x", first_duration.as_secs_f64() / second_duration.as_secs_f64());
info!("Cache size: {}", service.cache_size());
// Example 6: Normalized embeddings for similarity search
info!("\n6. Normalized embeddings for similarity");
let provider = HuggingFaceProvider::bge_small()?;
let normalize_options = EmbeddingOptions {
normalize: true,
truncate: true,
use_cache: true,
};
let query = "machine learning embeddings";
let doc1 = "neural network embeddings for NLP";
let doc2 = "cooking recipes and ingredients";
let query_emb = provider.embed(query, &normalize_options).await?;
let doc1_emb = provider.embed(doc1, &normalize_options).await?;
let doc2_emb = provider.embed(doc2, &normalize_options).await?;
let sim1 = stratum_embeddings::cosine_similarity(&query_emb, &doc1_emb);
let sim2 = stratum_embeddings::cosine_similarity(&query_emb, &doc2_emb);
info!("Query: '{}'", query);
info!("Similarity with doc1 ('{}'): {:.4}", doc1, sim1);
info!("Similarity with doc2 ('{}'): {:.4}", doc2, sim2);
info!("Most similar: {}", if sim1 > sim2 { "doc1" } else { "doc2" });
info!("\n=== Demo Complete ===");
Ok(())
}

View File

@ -0,0 +1,67 @@
use std::time::Duration;
use stratum_embeddings::{
EmbeddingOptions, EmbeddingService, FastEmbedProvider, LanceDbStore, MemoryCache, VectorStore,
VectorStoreConfig,
};
use tempfile::tempdir;
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
info!("Initializing embedding service...");
let provider = FastEmbedProvider::small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = EmbeddingService::new(provider).with_cache(cache);
let dir = tempdir()?;
let db_path = dir.path().to_str().unwrap();
info!("Creating LanceDB store at: {}", db_path);
let config = VectorStoreConfig::new(384);
let store = LanceDbStore::new(db_path, "embeddings", config).await?;
let documents = vec![
(
"doc1",
"Rust provides memory safety without garbage collection",
),
("doc2", "Knowledge graphs represent structured information"),
("doc3", "Vector databases enable semantic similarity search"),
("doc4", "Machine learning models learn from data patterns"),
("doc5", "Embeddings capture semantic meaning in vectors"),
];
info!("Embedding and storing {} documents...", documents.len());
let options = EmbeddingOptions::default_with_cache();
for (id, text) in &documents {
let embedding = service.embed(text, &options).await?;
let metadata = serde_json::json!({
"text": text,
"source": "demo"
});
store.upsert(id, &embedding, metadata).await?;
}
info!("Documents stored successfully");
info!("Performing semantic search...");
let query = "How do databases support similarity matching?";
let query_embedding = service.embed(query, &options).await?;
let results = store.search(&query_embedding, 3, None).await?;
info!("Search results for: '{}'", query);
for (i, result) in results.iter().enumerate() {
let text = result.metadata["text"].as_str().unwrap_or("N/A");
info!(" {}. [score: {:.4}] {}", i + 1, result.score, text);
}
let count = store.count().await?;
info!("Total documents in store: {}", count);
Ok(())
}

View File

@ -0,0 +1,66 @@
use std::time::Duration;
use stratum_embeddings::{
EmbeddingOptions, EmbeddingService, FastEmbedProvider, MemoryCache, SurrealDbStore,
VectorStore, VectorStoreConfig,
};
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
info!("Initializing embedding service...");
let provider = FastEmbedProvider::small()?;
let cache = MemoryCache::new(1000, Duration::from_secs(300));
let service = EmbeddingService::new(provider).with_cache(cache);
info!("Creating SurrealDB in-memory store...");
let config = VectorStoreConfig::new(384);
let store = SurrealDbStore::new_memory("concepts", config).await?;
let concepts = vec![
("ownership", "Rust's ownership system prevents memory leaks"),
(
"borrowing",
"Borrowing allows references without ownership transfer",
),
("lifetimes", "Lifetimes ensure references remain valid"),
("traits", "Traits define shared behavior across types"),
("generics", "Generics enable code reuse with type safety"),
];
info!("Embedding and storing {} concepts...", concepts.len());
let options = EmbeddingOptions::default_with_cache();
for (id, description) in &concepts {
let embedding = service.embed(description, &options).await?;
let metadata = serde_json::json!({
"concept": id,
"description": description,
"language": "rust"
});
store.upsert(id, &embedding, metadata).await?;
}
info!("Concepts stored successfully");
info!("Performing knowledge graph search...");
let query = "How does Rust manage memory?";
let query_embedding = service.embed(query, &options).await?;
let results = store.search(&query_embedding, 3, None).await?;
info!("Most relevant concepts for: '{}'", query);
for (i, result) in results.iter().enumerate() {
let concept = result.metadata["concept"].as_str().unwrap_or("N/A");
let description = result.metadata["description"].as_str().unwrap_or("N/A");
info!(" {}. {} [score: {:.4}]", i + 1, concept, result.score);
info!(" {}", description);
}
let count = store.count().await?;
info!("Total concepts in graph: {}", count);
Ok(())
}

View File

@ -0,0 +1,312 @@
use std::sync::Arc;
use futures::stream::{self, StreamExt};
use tracing::{debug, info};
use crate::{
cache::{cache_key, EmbeddingCache},
error::EmbeddingError,
traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};
pub struct BatchProcessor<P: EmbeddingProvider, C: EmbeddingCache> {
provider: Arc<P>,
cache: Option<Arc<C>>,
max_concurrent: usize,
}
impl<P: EmbeddingProvider, C: EmbeddingCache> BatchProcessor<P, C> {
pub fn new(provider: Arc<P>, cache: Option<Arc<C>>) -> Self {
Self {
provider,
cache,
max_concurrent: 10,
}
}
pub fn with_concurrency(mut self, max_concurrent: usize) -> Self {
self.max_concurrent = max_concurrent;
self
}
pub async fn process_batch(
&self,
texts: Vec<String>,
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Texts cannot be empty".to_string(),
));
}
let provider_batch_size = self.provider.max_batch_size();
let cache_enabled = options.use_cache && self.cache.is_some();
let mut all_embeddings = Vec::with_capacity(texts.len());
let mut total_tokens = 0u32;
let mut cached_count = 0usize;
for chunk in texts.chunks(provider_batch_size) {
let (cache_hits, cache_misses) = if cache_enabled {
self.check_cache(chunk, self.provider.name(), self.provider.model())
.await
} else {
(vec![None; chunk.len()], (0..chunk.len()).collect())
};
let mut chunk_embeddings = cache_hits;
cached_count += chunk_embeddings.iter().filter(|e| e.is_some()).count();
if !cache_misses.is_empty() {
let texts_to_embed: Vec<&str> = cache_misses
.iter()
.map(|&idx| chunk[idx].as_str())
.collect();
debug!(
"Embedding {} texts (cached: {}, new: {})",
chunk.len(),
cached_count,
texts_to_embed.len()
);
let result = self.provider.embed_batch(&texts_to_embed, options).await?;
if let Some(tokens) = result.total_tokens {
total_tokens += tokens;
}
if let Some(cache) = cache_enabled.then_some(self.cache.as_ref()).flatten() {
let cache_items = Self::build_cache_items(
self.provider.name(),
self.provider.model(),
chunk,
&cache_misses,
&result.embeddings,
);
cache.insert_batch(cache_items).await;
}
for (miss_idx, embedding) in cache_misses.iter().zip(result.embeddings.into_iter())
{
chunk_embeddings[*miss_idx] = Some(embedding);
}
}
all_embeddings.extend(
chunk_embeddings
.into_iter()
.map(|e| e.expect("Missing embedding")),
);
}
info!(
"Batch complete: {} embeddings ({} cached, {} new)",
texts.len(),
cached_count,
texts.len() - cached_count
);
Ok(EmbeddingResult {
embeddings: all_embeddings,
model: self.provider.model().to_string(),
dimensions: self.provider.dimensions(),
total_tokens: if total_tokens > 0 {
Some(total_tokens)
} else {
None
},
cached_count,
})
}
pub async fn process_stream(
&self,
texts: Vec<String>,
options: &EmbeddingOptions,
) -> Result<Vec<Embedding>, EmbeddingError> {
let provider = Arc::clone(&self.provider);
let cache = self.cache.clone();
let provider_name = provider.name().to_string();
let provider_model = provider.model().to_string();
let opts = options.clone();
let embeddings: Vec<Embedding> = stream::iter(texts)
.map(move |text| {
let provider = Arc::clone(&provider);
let cache = cache.clone();
let provider_name = provider_name.clone();
let provider_model = provider_model.clone();
let opts = opts.clone();
Self::embed_with_cache(provider, cache, text, provider_name, provider_model, opts)
})
.buffer_unordered(self.max_concurrent)
.collect::<Vec<Result<Embedding, EmbeddingError>>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
Ok(embeddings)
}
fn build_cache_items(
provider_name: &str,
provider_model: &str,
chunk: &[String],
cache_misses: &[usize],
embeddings: &[Embedding],
) -> Vec<(String, Embedding)> {
cache_misses
.iter()
.zip(embeddings.iter())
.map(|(&idx, emb)| {
(
cache_key(provider_name, provider_model, &chunk[idx]),
emb.clone(),
)
})
.collect()
}
async fn embed_with_cache(
provider: Arc<P>,
cache: Option<Arc<C>>,
text: String,
provider_name: String,
provider_model: String,
opts: EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
let key = cache_key(&provider_name, &provider_model, &text);
if opts.use_cache {
if let Some(cached) = Self::try_get_cached(&cache, &key).await {
return Ok(cached);
}
}
let embedding = provider.embed(&text, &opts).await?;
if opts.use_cache {
Self::cache_insert(&cache, &key, embedding.clone()).await;
}
Ok(embedding)
}
async fn try_get_cached(cache: &Option<Arc<C>>, key: &str) -> Option<Embedding> {
match cache {
Some(c) => c.get(key).await,
None => None,
}
}
async fn cache_insert(cache: &Option<Arc<C>>, key: &str, embedding: Embedding) {
if let Some(c) = cache {
c.insert(key, embedding).await;
}
}
async fn check_cache(
&self,
texts: &[String],
provider_name: &str,
model_name: &str,
) -> (Vec<Option<Embedding>>, Vec<usize>) {
let cache = match &self.cache {
Some(c) => c,
None => return (vec![None; texts.len()], (0..texts.len()).collect()),
};
let keys: Vec<String> = texts
.iter()
.map(|text| cache_key(provider_name, model_name, text))
.collect();
let cached = cache.get_batch(&keys).await;
let misses: Vec<usize> = cached
.iter()
.enumerate()
.filter(|(_, e)| e.is_none())
.map(|(i, _)| i)
.collect();
(cached, misses)
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::{cache::MemoryCache, providers::FastEmbedProvider};
#[tokio::test]
async fn test_batch_processor_no_cache() {
let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
let processor: BatchProcessor<_, MemoryCache> = BatchProcessor::new(provider, None);
let texts = vec!["Hello world".to_string(), "Goodbye world".to_string()];
let options = EmbeddingOptions::no_cache();
let result = processor
.process_batch(texts, &options)
.await
.expect("Failed to process");
assert_eq!(result.embeddings.len(), 2);
assert_eq!(result.cached_count, 0);
}
#[tokio::test]
async fn test_batch_processor_with_cache() {
let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
let cache = Arc::new(MemoryCache::new(100, Duration::from_secs(60)));
let processor = BatchProcessor::new(provider, Some(cache));
let texts = vec!["Hello world".to_string(), "Goodbye world".to_string()];
let options = EmbeddingOptions::default_with_cache();
let result1 = processor
.process_batch(texts.clone(), &options)
.await
.expect("Failed first batch");
assert_eq!(result1.embeddings.len(), 2);
assert_eq!(result1.cached_count, 0);
let result2 = processor
.process_batch(texts, &options)
.await
.expect("Failed second batch");
assert_eq!(result2.embeddings.len(), 2);
assert_eq!(result2.cached_count, 2);
assert_eq!(result1.embeddings, result2.embeddings);
}
#[tokio::test]
async fn test_batch_processor_stream() {
let provider = Arc::new(FastEmbedProvider::small().expect("Failed to init"));
let cache = Arc::new(MemoryCache::with_defaults());
let processor = BatchProcessor::new(provider, Some(cache)).with_concurrency(2);
let texts = vec![
"Text 1".to_string(),
"Text 2".to_string(),
"Text 3".to_string(),
"Text 4".to_string(),
];
let options = EmbeddingOptions::default_with_cache();
let embeddings = processor
.process_stream(texts, &options)
.await
.expect("Failed stream");
assert_eq!(embeddings.len(), 4);
}
}

View File

@ -0,0 +1,167 @@
#[cfg(feature = "memory-cache")]
use std::time::Duration;
#[cfg(feature = "memory-cache")]
use async_trait::async_trait;
#[cfg(feature = "memory-cache")]
use moka::future::Cache;
#[cfg(feature = "memory-cache")]
use crate::{cache::EmbeddingCache, traits::Embedding};
pub struct MemoryCache {
cache: Cache<String, Embedding>,
}
impl MemoryCache {
pub fn new(max_capacity: u64, ttl: Duration) -> Self {
let cache = Cache::builder()
.max_capacity(max_capacity)
.time_to_live(ttl)
.build();
Self { cache }
}
pub fn with_defaults() -> Self {
Self::new(10_000, Duration::from_secs(3600))
}
pub fn unlimited(ttl: Duration) -> Self {
let cache = Cache::builder().time_to_live(ttl).build();
Self { cache }
}
}
impl Default for MemoryCache {
fn default() -> Self {
Self::with_defaults()
}
}
#[async_trait]
impl EmbeddingCache for MemoryCache {
async fn get(&self, key: &str) -> Option<Embedding> {
self.cache.get(key).await
}
async fn insert(&self, key: &str, embedding: Embedding) {
self.cache.insert(key.to_string(), embedding).await;
self.cache.run_pending_tasks().await;
}
async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>> {
let mut results = Vec::with_capacity(keys.len());
for key in keys {
results.push(self.cache.get(key).await);
}
results
}
async fn insert_batch(&self, items: Vec<(String, Embedding)>) {
for (key, embedding) in items {
self.cache.insert(key, embedding).await;
}
self.cache.run_pending_tasks().await;
}
async fn invalidate(&self, key: &str) {
self.cache.invalidate(key).await;
self.cache.run_pending_tasks().await;
}
async fn clear(&self) {
self.cache.invalidate_all();
self.cache.run_pending_tasks().await;
}
fn size(&self) -> usize {
self.cache.entry_count() as usize
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_memory_cache_basic() {
let cache = MemoryCache::with_defaults();
let embedding = vec![1.0, 2.0, 3.0];
cache.insert("test_key", embedding.clone()).await;
let retrieved = cache.get("test_key").await;
assert_eq!(retrieved, Some(embedding));
let missing = cache.get("missing_key").await;
assert_eq!(missing, None);
}
#[tokio::test]
async fn test_memory_cache_batch() {
let cache = MemoryCache::with_defaults();
let items = vec![
("key1".to_string(), vec![1.0, 2.0]),
("key2".to_string(), vec![3.0, 4.0]),
("key3".to_string(), vec![5.0, 6.0]),
];
cache.insert_batch(items.clone()).await;
let keys = vec![
"key1".to_string(),
"key2".to_string(),
"missing".to_string(),
];
let results = cache.get_batch(&keys).await;
assert_eq!(results.len(), 3);
assert_eq!(results[0], Some(vec![1.0, 2.0]));
assert_eq!(results[1], Some(vec![3.0, 4.0]));
assert_eq!(results[2], None);
}
#[tokio::test]
async fn test_memory_cache_invalidate() {
let cache = MemoryCache::with_defaults();
cache.insert("key1", vec![1.0, 2.0]).await;
cache.insert("key2", vec![3.0, 4.0]).await;
assert!(cache.get("key1").await.is_some());
assert!(cache.get("key2").await.is_some());
cache.invalidate("key1").await;
assert!(cache.get("key1").await.is_none());
assert!(cache.get("key2").await.is_some());
}
#[tokio::test]
async fn test_memory_cache_clear() {
let cache = MemoryCache::with_defaults();
cache.insert("key1", vec![1.0]).await;
cache.insert("key2", vec![2.0]).await;
cache.clear().await;
assert!(cache.get("key1").await.is_none());
assert!(cache.get("key2").await.is_none());
}
#[tokio::test]
async fn test_memory_cache_ttl() {
let cache = MemoryCache::new(1000, Duration::from_millis(100));
cache.insert("key1", vec![1.0, 2.0]).await;
assert!(cache.get("key1").await.is_some());
tokio::time::sleep(Duration::from_millis(150)).await;
assert!(cache.get("key1").await.is_none());
}
}

View File

@ -0,0 +1,47 @@
#[cfg(feature = "memory-cache")]
pub mod memory;
#[cfg(feature = "persistent-cache")]
pub mod persistent;
use async_trait::async_trait;
#[cfg(feature = "memory-cache")]
pub use memory::MemoryCache;
#[cfg(feature = "persistent-cache")]
pub use persistent::PersistentCache;
use crate::traits::Embedding;
#[async_trait]
pub trait EmbeddingCache: Send + Sync {
async fn get(&self, key: &str) -> Option<Embedding>;
async fn insert(&self, key: &str, embedding: Embedding);
async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>>;
async fn insert_batch(&self, items: Vec<(String, Embedding)>);
async fn invalidate(&self, key: &str);
async fn clear(&self);
fn size(&self) -> usize;
}
pub fn cache_key(provider: &str, model: &str, text: &str) -> String {
use xxhash_rust::xxh3::xxh3_64;
let hash = xxh3_64(format!("{}:{}:{}", provider, model, text).as_bytes());
format!("{}:{}:{:x}", provider, model, hash)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_key_consistency() {
let key1 = cache_key("fastembed", "bge-small", "hello world");
let key2 = cache_key("fastembed", "bge-small", "hello world");
assert_eq!(key1, key2);
let key3 = cache_key("fastembed", "bge-small", "hello world!");
assert_ne!(key1, key3);
let key4 = cache_key("openai", "bge-small", "hello world");
assert_ne!(key1, key4);
}
}

View File

@ -0,0 +1,152 @@
#[cfg(feature = "persistent-cache")]
use std::path::Path;
#[cfg(feature = "persistent-cache")]
use async_trait::async_trait;
#[cfg(feature = "persistent-cache")]
use sled::Db;
#[cfg(feature = "persistent-cache")]
use crate::{cache::EmbeddingCache, error::EmbeddingError, traits::Embedding};
pub struct PersistentCache {
db: Db,
}
impl PersistentCache {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, EmbeddingError> {
let db = sled::open(path)
.map_err(|e| EmbeddingError::CacheError(format!("Failed to open sled db: {}", e)))?;
Ok(Self { db })
}
pub fn in_memory() -> Result<Self, EmbeddingError> {
let db = sled::Config::new().temporary(true).open().map_err(|e| {
EmbeddingError::CacheError(format!("Failed to create in-memory sled db: {}", e))
})?;
Ok(Self { db })
}
fn serialize_embedding(embedding: &Embedding) -> Result<Vec<u8>, EmbeddingError> {
serde_json::to_vec(embedding)
.map_err(|e| EmbeddingError::SerializationError(format!("Embedding serialize: {}", e)))
}
fn deserialize_embedding(data: &[u8]) -> Result<Embedding, EmbeddingError> {
serde_json::from_slice(data).map_err(|e| {
EmbeddingError::SerializationError(format!("Embedding deserialize: {}", e))
})
}
}
#[async_trait]
impl EmbeddingCache for PersistentCache {
async fn get(&self, key: &str) -> Option<Embedding> {
self.db
.get(key)
.ok()
.flatten()
.and_then(|bytes| Self::deserialize_embedding(&bytes).ok())
}
async fn insert(&self, key: &str, embedding: Embedding) {
if let Ok(bytes) = Self::serialize_embedding(&embedding) {
let _ = self.db.insert(key, bytes);
let _ = self.db.flush();
}
}
async fn get_batch(&self, keys: &[String]) -> Vec<Option<Embedding>> {
keys.iter()
.map(|key| {
self.db
.get(key)
.ok()
.flatten()
.and_then(|bytes| Self::deserialize_embedding(&bytes).ok())
})
.collect()
}
async fn insert_batch(&self, items: Vec<(String, Embedding)>) {
for (key, embedding) in items {
if let Ok(bytes) = Self::serialize_embedding(&embedding) {
let _ = self.db.insert(key, bytes);
}
}
let _ = self.db.flush();
}
async fn invalidate(&self, key: &str) {
let _ = self.db.remove(key);
let _ = self.db.flush();
}
async fn clear(&self) {
let _ = self.db.clear();
let _ = self.db.flush();
}
fn size(&self) -> usize {
self.db.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_persistent_cache_in_memory() {
let cache = PersistentCache::in_memory().expect("Failed to create cache");
let embedding = vec![1.0, 2.0, 3.0];
cache.insert("test_key", embedding.clone()).await;
let retrieved = cache.get("test_key").await;
assert_eq!(retrieved, Some(embedding));
}
#[tokio::test]
async fn test_persistent_cache_batch() {
let cache = PersistentCache::in_memory().expect("Failed to create cache");
let items = vec![
("key1".to_string(), vec![1.0, 2.0]),
("key2".to_string(), vec![3.0, 4.0]),
];
cache.insert_batch(items).await;
let keys = vec!["key1".to_string(), "key2".to_string()];
let results = cache.get_batch(&keys).await;
assert_eq!(results[0], Some(vec![1.0, 2.0]));
assert_eq!(results[1], Some(vec![3.0, 4.0]));
}
#[tokio::test]
async fn test_persistent_cache_invalidate() {
let cache = PersistentCache::in_memory().expect("Failed to create cache");
cache.insert("key1", vec![1.0]).await;
assert!(cache.get("key1").await.is_some());
cache.invalidate("key1").await;
assert!(cache.get("key1").await.is_none());
}
#[tokio::test]
async fn test_persistent_cache_clear() {
let cache = PersistentCache::in_memory().expect("Failed to create cache");
cache.insert("key1", vec![1.0]).await;
cache.insert("key2", vec![2.0]).await;
assert_eq!(cache.size(), 2);
cache.clear().await;
assert_eq!(cache.size(), 0);
}
}

View File

@ -0,0 +1,153 @@
use std::time::Duration;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
pub provider: ProviderConfig,
pub cache: CacheConfig,
#[serde(default)]
pub batch: BatchConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
FastEmbed {
model: String,
},
OpenAI {
api_key: String,
model: String,
base_url: Option<String>,
},
Ollama {
model: String,
base_url: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
pub enabled: bool,
pub max_capacity: u64,
#[serde(with = "humantime_serde")]
pub ttl: Duration,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
max_capacity: 10_000,
ttl: Duration::from_secs(3600),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
pub max_concurrent: usize,
pub chunk_size: Option<usize>,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_concurrent: 10,
chunk_size: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreSettings {
pub dimensions: usize,
pub metric: String,
}
impl EmbeddingConfig {
pub fn from_env(provider_type: &str) -> Result<Self, crate::error::EmbeddingError> {
let provider = match provider_type {
"fastembed" => {
let model =
std::env::var("FASTEMBED_MODEL").unwrap_or_else(|_| "bge-small-en".to_string());
ProviderConfig::FastEmbed { model }
}
"openai" => {
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
crate::error::EmbeddingError::ConfigError("OPENAI_API_KEY not set".to_string())
})?;
let model = std::env::var("OPENAI_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let base_url = std::env::var("OPENAI_BASE_URL").ok();
ProviderConfig::OpenAI {
api_key,
model,
base_url,
}
}
"ollama" => {
let model = std::env::var("OLLAMA_MODEL")
.unwrap_or_else(|_| "nomic-embed-text".to_string());
let base_url = std::env::var("OLLAMA_BASE_URL").ok();
ProviderConfig::Ollama { model, base_url }
}
_ => {
return Err(crate::error::EmbeddingError::ConfigError(format!(
"Unknown provider type: {}",
provider_type
)))
}
};
Ok(Self {
provider,
cache: CacheConfig::default(),
batch: BatchConfig::default(),
})
}
pub fn with_cache(mut self, config: CacheConfig) -> Self {
self.cache = config;
self
}
pub fn with_batch(mut self, config: BatchConfig) -> Self {
self.batch = config;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_serialization() {
let config = EmbeddingConfig {
provider: ProviderConfig::FastEmbed {
model: "bge-small-en".to_string(),
},
cache: CacheConfig::default(),
batch: BatchConfig::default(),
};
let json = serde_json::to_string(&config).expect("Failed to serialize");
let deserialized: EmbeddingConfig =
serde_json::from_str(&json).expect("Failed to deserialize");
match deserialized.provider {
ProviderConfig::FastEmbed { model } => assert_eq!(model, "bge-small-en"),
_ => panic!("Wrong provider type"),
}
}
#[test]
fn test_cache_config_defaults() {
let config = CacheConfig::default();
assert!(config.enabled);
assert_eq!(config.max_capacity, 10_000);
assert_eq!(config.ttl, Duration::from_secs(3600));
}
}

View File

@ -0,0 +1,76 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum EmbeddingError {
#[error("Provider initialization failed: {0}")]
Initialization(String),
#[error("Provider not available: {0}")]
ProviderUnavailable(String),
#[error("API request failed: {0}")]
ApiError(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
#[error("Batch size {size} exceeds maximum {max}")]
BatchSizeExceeded { size: usize, max: usize },
#[error("Cache error: {0}")]
CacheError(String),
#[error("Store error: {0}")]
StoreError(String),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
#[error("Timeout: {0}")]
Timeout(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("IO error: {0}")]
IoError(String),
#[error("HTTP error: {0}")]
HttpError(String),
#[error(transparent)]
Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
impl From<std::io::Error> for EmbeddingError {
fn from(err: std::io::Error) -> Self {
Self::IoError(err.to_string())
}
}
impl From<serde_json::Error> for EmbeddingError {
fn from(err: serde_json::Error) -> Self {
Self::SerializationError(err.to_string())
}
}
#[cfg(feature = "reqwest")]
impl From<reqwest::Error> for EmbeddingError {
fn from(err: reqwest::Error) -> Self {
if err.is_timeout() {
Self::Timeout(err.to_string())
} else if err.is_status() {
Self::HttpError(format!("HTTP {}: {}", err.status().unwrap(), err))
} else {
Self::ApiError(err.to_string())
}
}
}
pub type Result<T> = std::result::Result<T, EmbeddingError>;

View File

@ -0,0 +1,34 @@
pub mod batch;
pub mod cache;
pub mod config;
pub mod error;
pub mod metrics;
pub mod providers;
pub mod service;
pub mod store;
pub mod traits;
#[cfg(feature = "memory-cache")]
pub use cache::MemoryCache;
#[cfg(feature = "persistent-cache")]
pub use cache::PersistentCache;
pub use config::{BatchConfig, CacheConfig, EmbeddingConfig, ProviderConfig};
pub use error::{EmbeddingError, Result};
#[cfg(feature = "cohere-provider")]
pub use providers::cohere::{CohereModel, CohereProvider};
#[cfg(feature = "fastembed-provider")]
pub use providers::fastembed::{FastEmbedModel, FastEmbedProvider};
#[cfg(feature = "huggingface-provider")]
pub use providers::huggingface::{HuggingFaceModel, HuggingFaceProvider};
#[cfg(feature = "ollama-provider")]
pub use providers::ollama::{OllamaModel, OllamaProvider};
#[cfg(feature = "openai-provider")]
pub use providers::openai::{OpenAiModel, OpenAiProvider};
#[cfg(feature = "voyage-provider")]
pub use providers::voyage::{VoyageModel, VoyageProvider};
pub use service::EmbeddingService;
pub use store::*;
pub use traits::{
cosine_similarity, euclidean_distance, normalize_embedding, Embedding, EmbeddingOptions,
EmbeddingProvider, EmbeddingResult, ProviderInfo,
};

View File

@ -0,0 +1,195 @@
#[cfg(feature = "metrics")]
use std::sync::OnceLock;
#[cfg(feature = "metrics")]
use prometheus::{
register_histogram_vec, register_int_counter_vec, HistogramOpts, HistogramVec, IntCounterVec,
Opts,
};
#[cfg(feature = "metrics")]
static EMBEDDING_REQUESTS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static EMBEDDING_ERRORS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static EMBEDDING_DURATION: OnceLock<HistogramVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static CACHE_HITS: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static CACHE_MISSES: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
static TOKENS_PROCESSED: OnceLock<IntCounterVec> = OnceLock::new();
#[cfg(feature = "metrics")]
pub fn init_metrics() -> Result<(), Box<dyn std::error::Error>> {
EMBEDDING_REQUESTS.get_or_init(|| {
register_int_counter_vec!(
Opts::new(
"embedding_requests_total",
"Total number of embedding requests"
),
&["provider", "model"]
)
.expect("Failed to register embedding_requests_total")
});
EMBEDDING_ERRORS.get_or_init(|| {
register_int_counter_vec!(
Opts::new("embedding_errors_total", "Total number of embedding errors"),
&["provider", "model", "error_type"]
)
.expect("Failed to register embedding_errors_total")
});
EMBEDDING_DURATION.get_or_init(|| {
register_histogram_vec!(
HistogramOpts::new("embedding_duration_seconds", "Embedding request duration")
.buckets(vec![0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0]),
&["provider", "model"]
)
.expect("Failed to register embedding_duration_seconds")
});
CACHE_HITS.get_or_init(|| {
register_int_counter_vec!(
Opts::new("embedding_cache_hits_total", "Total cache hits"),
&["provider", "model"]
)
.expect("Failed to register embedding_cache_hits_total")
});
CACHE_MISSES.get_or_init(|| {
register_int_counter_vec!(
Opts::new("embedding_cache_misses_total", "Total cache misses"),
&["provider", "model"]
)
.expect("Failed to register embedding_cache_misses_total")
});
TOKENS_PROCESSED.get_or_init(|| {
register_int_counter_vec!(
Opts::new("embedding_tokens_processed_total", "Total tokens processed"),
&["provider", "model"]
)
.expect("Failed to register embedding_tokens_processed_total")
});
Ok(())
}
#[cfg(feature = "metrics")]
pub fn record_request(provider: &str, model: &str) {
if let Some(counter) = EMBEDDING_REQUESTS.get() {
counter.with_label_values(&[provider, model]).inc();
}
}
#[cfg(feature = "metrics")]
pub fn record_error(provider: &str, model: &str, error_type: &str) {
if let Some(counter) = EMBEDDING_ERRORS.get() {
counter
.with_label_values(&[provider, model, error_type])
.inc();
}
}
#[cfg(feature = "metrics")]
pub fn record_duration(provider: &str, model: &str, duration_secs: f64) {
if let Some(histogram) = EMBEDDING_DURATION.get() {
histogram
.with_label_values(&[provider, model])
.observe(duration_secs);
}
}
#[cfg(feature = "metrics")]
pub fn record_cache_hit(provider: &str, model: &str) {
if let Some(counter) = CACHE_HITS.get() {
counter.with_label_values(&[provider, model]).inc();
}
}
#[cfg(feature = "metrics")]
pub fn record_cache_miss(provider: &str, model: &str) {
if let Some(counter) = CACHE_MISSES.get() {
counter.with_label_values(&[provider, model]).inc();
}
}
#[cfg(feature = "metrics")]
pub fn record_tokens(provider: &str, model: &str, tokens: u64) {
if let Some(counter) = TOKENS_PROCESSED.get() {
counter.with_label_values(&[provider, model]).inc_by(tokens);
}
}
#[cfg(feature = "metrics")]
pub struct MetricsGuard {
provider: String,
model: String,
start: std::time::Instant,
}
#[cfg(feature = "metrics")]
impl MetricsGuard {
pub fn new(provider: &str, model: &str) -> Self {
record_request(provider, model);
Self {
provider: provider.to_string(),
model: model.to_string(),
start: std::time::Instant::now(),
}
}
pub fn record_success(self, tokens: Option<u32>) {
let duration = self.start.elapsed().as_secs_f64();
record_duration(&self.provider, &self.model, duration);
if let Some(token_count) = tokens {
record_tokens(&self.provider, &self.model, token_count as u64);
}
}
pub fn record_error(self, error_type: &str) {
record_error(&self.provider, &self.model, error_type);
let duration = self.start.elapsed().as_secs_f64();
record_duration(&self.provider, &self.model, duration);
}
}
#[cfg(not(feature = "metrics"))]
pub fn init_metrics() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[cfg(test)]
#[cfg(feature = "metrics")]
mod tests {
use super::*;
#[test]
fn test_metrics_initialization() {
let result = init_metrics();
assert!(result.is_ok());
}
#[test]
fn test_record_request() {
init_metrics().unwrap();
record_request("test-provider", "test-model");
}
#[test]
fn test_metrics_guard() {
init_metrics().unwrap();
let guard = MetricsGuard::new("test", "model");
guard.record_success(Some(100));
}
#[test]
fn test_cache_metrics() {
init_metrics().unwrap();
record_cache_hit("test", "model");
record_cache_miss("test", "model");
}
}

View File

@ -0,0 +1,251 @@
#[cfg(feature = "cohere-provider")]
use async_trait::async_trait;
#[cfg(feature = "cohere-provider")]
use reqwest::Client;
#[cfg(feature = "cohere-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "cohere-provider")]
use crate::{
error::EmbeddingError,
traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CohereModel {
#[default]
EmbedEnglishV3,
EmbedMultilingualV3,
EmbedEnglishLightV3,
EmbedMultilingualLightV3,
EmbedEnglishV2,
EmbedMultilingualV2,
}
impl CohereModel {
pub fn model_name(&self) -> &'static str {
match self {
Self::EmbedEnglishV3 => "embed-english-v3.0",
Self::EmbedMultilingualV3 => "embed-multilingual-v3.0",
Self::EmbedEnglishLightV3 => "embed-english-light-v3.0",
Self::EmbedMultilingualLightV3 => "embed-multilingual-light-v3.0",
Self::EmbedEnglishV2 => "embed-english-v2.0",
Self::EmbedMultilingualV2 => "embed-multilingual-v2.0",
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::EmbedEnglishV3 | Self::EmbedMultilingualV3 => 1024,
Self::EmbedEnglishLightV3 | Self::EmbedMultilingualLightV3 => 384,
Self::EmbedEnglishV2 | Self::EmbedMultilingualV2 => 4096,
}
}
}
#[derive(Debug, Serialize)]
struct CohereEmbedRequest {
model: String,
texts: Vec<String>,
input_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
truncate: Option<String>,
}
#[derive(Debug, Deserialize)]
struct CohereEmbedResponse {
embeddings: Vec<Vec<f32>>,
meta: CohereMeta,
}
#[derive(Debug, Deserialize)]
struct CohereMeta {
billed_units: Option<BilledUnits>,
}
#[derive(Debug, Deserialize)]
struct BilledUnits {
input_tokens: Option<u32>,
}
pub struct CohereProvider {
client: Client,
api_key: String,
model: CohereModel,
base_url: String,
}
impl CohereProvider {
pub fn new(api_key: String, model: CohereModel) -> Self {
Self {
client: Client::new(),
api_key,
model,
base_url: "https://api.cohere.ai/v1".to_string(),
}
}
pub fn with_base_url(mut self, base_url: String) -> Self {
self.base_url = base_url;
self
}
pub fn embed_english_v3(api_key: String) -> Self {
Self::new(api_key, CohereModel::EmbedEnglishV3)
}
pub fn embed_multilingual_v3(api_key: String) -> Self {
Self::new(api_key, CohereModel::EmbedMultilingualV3)
}
async fn embed_batch_internal(
&self,
texts: &[&str],
_options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
let request = CohereEmbedRequest {
model: self.model.model_name().to_string(),
texts: texts.iter().map(|s| s.to_string()).collect(),
input_type: "search_document".to_string(),
truncate: Some("END".to_string()),
};
let response = self
.client
.post(format!("{}/embed", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| EmbeddingError::ApiError(format!("Cohere API request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(EmbeddingError::ApiError(format!(
"Cohere API error {}: {}",
status, error_text
)));
}
let result: CohereEmbedResponse = response.json().await.map_err(|e| {
EmbeddingError::ApiError(format!("Failed to parse Cohere response: {}", e))
})?;
let total_tokens = result.meta.billed_units.and_then(|b| b.input_tokens);
Ok(EmbeddingResult {
embeddings: result.embeddings,
model: self.model.model_name().to_string(),
dimensions: self.model.dimensions(),
total_tokens,
cached_count: 0,
})
}
}
#[async_trait]
impl EmbeddingProvider for CohereProvider {
fn name(&self) -> &str {
"cohere"
}
fn model(&self) -> &str {
self.model.model_name()
}
fn dimensions(&self) -> usize {
self.model.dimensions()
}
fn is_local(&self) -> bool {
false
}
fn max_tokens(&self) -> usize {
match self.model {
CohereModel::EmbedEnglishV3
| CohereModel::EmbedMultilingualV3
| CohereModel::EmbedEnglishLightV3
| CohereModel::EmbedMultilingualLightV3 => 512,
CohereModel::EmbedEnglishV2 | CohereModel::EmbedMultilingualV2 => 512,
}
}
fn max_batch_size(&self) -> usize {
96
}
fn cost_per_1m_tokens(&self) -> f64 {
match self.model {
CohereModel::EmbedEnglishV3 | CohereModel::EmbedMultilingualV3 => 0.10,
CohereModel::EmbedEnglishLightV3 | CohereModel::EmbedMultilingualLightV3 => 0.10,
CohereModel::EmbedEnglishV2 | CohereModel::EmbedMultilingualV2 => 0.10,
}
}
async fn is_available(&self) -> bool {
!self.api_key.is_empty()
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
let result = self.embed_batch_internal(&[text], options).await?;
result
.embeddings
.into_iter()
.next()
.ok_or_else(|| EmbeddingError::ApiError("No embedding returned".to_string()))
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Cannot embed empty text list".to_string(),
));
}
if texts.len() > self.max_batch_size() {
return Err(EmbeddingError::InvalidInput(format!(
"Batch size {} exceeds maximum {}",
texts.len(),
self.max_batch_size()
)));
}
self.embed_batch_internal(texts, options).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cohere_model_names() {
assert_eq!(
CohereModel::EmbedEnglishV3.model_name(),
"embed-english-v3.0"
);
assert_eq!(CohereModel::EmbedMultilingualV3.dimensions(), 1024);
assert_eq!(CohereModel::EmbedEnglishLightV3.dimensions(), 384);
assert_eq!(CohereModel::EmbedEnglishV2.dimensions(), 4096);
}
#[tokio::test]
async fn test_cohere_provider_creation() {
let provider = CohereProvider::new("test-key".to_string(), CohereModel::EmbedEnglishV3);
assert_eq!(provider.name(), "cohere");
assert_eq!(provider.model(), "embed-english-v3.0");
assert_eq!(provider.dimensions(), 1024);
assert_eq!(provider.max_batch_size(), 96);
}
}

View File

@ -0,0 +1,247 @@
#[cfg(feature = "fastembed-provider")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "fastembed-provider")]
use async_trait::async_trait;
#[cfg(feature = "fastembed-provider")]
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
#[cfg(feature = "fastembed-provider")]
use crate::{
error::EmbeddingError,
traits::{
normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
},
};
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FastEmbedModel {
#[default]
BgeSmallEn,
BgeBaseEn,
BgeLargeEn,
AllMiniLmL6V2,
MultilingualE5Small,
MultilingualE5Base,
}
impl FastEmbedModel {
pub fn dimensions(&self) -> usize {
match self {
Self::BgeSmallEn | Self::AllMiniLmL6V2 | Self::MultilingualE5Small => 384,
Self::BgeBaseEn | Self::MultilingualE5Base => 768,
Self::BgeLargeEn => 1024,
}
}
pub fn model_name(&self) -> &'static str {
match self {
Self::BgeSmallEn => "BAAI/bge-small-en-v1.5",
Self::BgeBaseEn => "BAAI/bge-base-en-v1.5",
Self::BgeLargeEn => "BAAI/bge-large-en-v1.5",
Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
Self::MultilingualE5Small => "intfloat/multilingual-e5-small",
Self::MultilingualE5Base => "intfloat/multilingual-e5-base",
}
}
fn to_fastembed_model(self) -> EmbeddingModel {
match self {
Self::BgeSmallEn => EmbeddingModel::BGESmallENV15,
Self::BgeBaseEn => EmbeddingModel::BGEBaseENV15,
Self::BgeLargeEn => EmbeddingModel::BGELargeENV15,
Self::AllMiniLmL6V2 => EmbeddingModel::AllMiniLML6V2,
Self::MultilingualE5Small => EmbeddingModel::MultilingualE5Small,
Self::MultilingualE5Base => EmbeddingModel::MultilingualE5Base,
}
}
}
pub struct FastEmbedProvider {
model: Arc<Mutex<TextEmbedding>>,
model_type: FastEmbedModel,
}
impl FastEmbedProvider {
pub fn new(model_type: FastEmbedModel) -> Result<Self, EmbeddingError> {
let options =
InitOptions::new(model_type.to_fastembed_model()).with_show_download_progress(true);
let model = TextEmbedding::try_new(options)
.map_err(|e| EmbeddingError::Initialization(e.to_string()))?;
Ok(Self {
model: Arc::new(Mutex::new(model)),
model_type,
})
}
pub fn default_model() -> Result<Self, EmbeddingError> {
Self::new(FastEmbedModel::default())
}
pub fn small() -> Result<Self, EmbeddingError> {
Self::new(FastEmbedModel::BgeSmallEn)
}
pub fn base() -> Result<Self, EmbeddingError> {
Self::new(FastEmbedModel::BgeBaseEn)
}
pub fn large() -> Result<Self, EmbeddingError> {
Self::new(FastEmbedModel::BgeLargeEn)
}
pub fn multilingual() -> Result<Self, EmbeddingError> {
Self::new(FastEmbedModel::MultilingualE5Base)
}
}
#[async_trait]
impl EmbeddingProvider for FastEmbedProvider {
fn name(&self) -> &str {
"fastembed"
}
fn model(&self) -> &str {
self.model_type.model_name()
}
fn dimensions(&self) -> usize {
self.model_type.dimensions()
}
fn is_local(&self) -> bool {
true
}
fn max_tokens(&self) -> usize {
512
}
fn max_batch_size(&self) -> usize {
256
}
fn cost_per_1m_tokens(&self) -> f64 {
0.0
}
async fn is_available(&self) -> bool {
true
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
let text = text.to_string();
let model = Arc::clone(&self.model);
let embeddings = tokio::task::spawn_blocking(move || {
let mut model_guard = model.lock().expect("Failed to acquire model lock");
model_guard.embed(vec![text], None)
})
.await
.map_err(|e| EmbeddingError::ApiError(format!("Task join error: {}", e)))?
.map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
let mut embedding = embeddings.into_iter().next().ok_or_else(|| {
EmbeddingError::ApiError("FastEmbed returned no embeddings".to_string())
})?;
if options.normalize {
normalize_embedding(&mut embedding);
}
Ok(embedding)
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.len() > self.max_batch_size() {
return Err(EmbeddingError::BatchSizeExceeded {
size: texts.len(),
max: self.max_batch_size(),
});
}
let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
let model = Arc::clone(&self.model);
let mut embeddings = tokio::task::spawn_blocking(move || {
let mut model_guard = model.lock().expect("Failed to acquire model lock");
model_guard.embed(texts_owned, None)
})
.await
.map_err(|e| EmbeddingError::ApiError(format!("Task join error: {}", e)))?
.map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
if options.normalize {
embeddings.iter_mut().for_each(normalize_embedding);
}
Ok(EmbeddingResult {
embeddings,
model: self.model().to_string(),
dimensions: self.dimensions(),
total_tokens: None,
cached_count: 0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_fastembed_provider() {
let provider = FastEmbedProvider::small().expect("Failed to initialize FastEmbed");
assert_eq!(provider.name(), "fastembed");
assert_eq!(provider.dimensions(), 384);
assert!(provider.is_local());
assert_eq!(provider.cost_per_1m_tokens(), 0.0);
let options = EmbeddingOptions::default_with_cache();
let embedding = provider
.embed("Hello world", &options)
.await
.expect("Failed to embed");
assert_eq!(embedding.len(), 384);
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(magnitude - 1.0).abs() < 0.01,
"Embedding should be normalized"
);
}
#[tokio::test]
async fn test_fastembed_batch() {
let provider = FastEmbedProvider::small().expect("Failed to initialize FastEmbed");
let texts = vec!["Hello world", "Goodbye world", "Machine learning"];
let options = EmbeddingOptions::default_with_cache();
let result = provider
.embed_batch(&texts, &options)
.await
.expect("Failed to embed batch");
assert_eq!(result.embeddings.len(), 3);
assert_eq!(result.dimensions, 384);
for embedding in &result.embeddings {
assert_eq!(embedding.len(), 384);
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((magnitude - 1.0).abs() < 0.01);
}
}
}

View File

@ -0,0 +1,344 @@
#[cfg(feature = "huggingface-provider")]
use std::time::Duration;
#[cfg(feature = "huggingface-provider")]
use async_trait::async_trait;
#[cfg(feature = "huggingface-provider")]
use reqwest::Client;
#[cfg(feature = "huggingface-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "huggingface-provider")]
use crate::{
error::EmbeddingError,
traits::{
normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
},
};
#[derive(Debug, Clone, PartialEq)]
pub enum HuggingFaceModel {
/// BAAI/bge-small-en-v1.5 - 384 dimensions, efficient general-purpose
BgeSmall,
/// BAAI/bge-base-en-v1.5 - 768 dimensions, balanced performance
BgeBase,
/// BAAI/bge-large-en-v1.5 - 1024 dimensions, high quality
BgeLarge,
/// sentence-transformers/all-MiniLM-L6-v2 - 384 dimensions, fast
AllMiniLm,
/// sentence-transformers/all-mpnet-base-v2 - 768 dimensions, strong baseline
AllMpnet,
/// Custom model with model ID and dimensions
Custom(String, usize),
}
impl Default for HuggingFaceModel {
fn default() -> Self {
Self::BgeSmall
}
}
impl HuggingFaceModel {
pub fn model_id(&self) -> &str {
match self {
Self::BgeSmall => "BAAI/bge-small-en-v1.5",
Self::BgeBase => "BAAI/bge-base-en-v1.5",
Self::BgeLarge => "BAAI/bge-large-en-v1.5",
Self::AllMiniLm => "sentence-transformers/all-MiniLM-L6-v2",
Self::AllMpnet => "sentence-transformers/all-mpnet-base-v2",
Self::Custom(id, _) => id.as_str(),
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::BgeSmall => 384,
Self::BgeBase => 768,
Self::BgeLarge => 1024,
Self::AllMiniLm => 384,
Self::AllMpnet => 768,
Self::Custom(_, dims) => *dims,
}
}
pub fn from_model_id(id: &str, dimensions: Option<usize>) -> Self {
match id {
"BAAI/bge-small-en-v1.5" => Self::BgeSmall,
"BAAI/bge-base-en-v1.5" => Self::BgeBase,
"BAAI/bge-large-en-v1.5" => Self::BgeLarge,
"sentence-transformers/all-MiniLM-L6-v2" => Self::AllMiniLm,
"sentence-transformers/all-mpnet-base-v2" => Self::AllMpnet,
_ => Self::Custom(
id.to_string(),
dimensions.unwrap_or(384), // Default to 384 if unknown
),
}
}
}
#[cfg(feature = "huggingface-provider")]
pub struct HuggingFaceProvider {
client: Client,
api_key: String,
model: HuggingFaceModel,
}
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Serialize)]
struct HFRequest {
inputs: String,
}
#[cfg(feature = "huggingface-provider")]
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HFResponse {
/// Single text embedding response
Single(Vec<f32>),
/// Batch embedding response
Multiple(Vec<Vec<f32>>),
}
#[cfg(feature = "huggingface-provider")]
impl HuggingFaceProvider {
const BASE_URL: &'static str = "https://api-inference.huggingface.co/pipeline/feature-extraction";
pub fn new(api_key: impl Into<String>, model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
let api_key = api_key.into();
if api_key.is_empty() {
return Err(EmbeddingError::ConfigError(
"HuggingFace API key is empty".to_string(),
));
}
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| EmbeddingError::Initialization(e.to_string()))?;
Ok(Self {
client,
api_key,
model,
})
}
pub fn from_env(model: HuggingFaceModel) -> Result<Self, EmbeddingError> {
let api_key = std::env::var("HUGGINGFACE_API_KEY")
.or_else(|_| std::env::var("HF_TOKEN"))
.map_err(|_| {
EmbeddingError::ConfigError(
"HUGGINGFACE_API_KEY or HF_TOKEN environment variable not set".to_string(),
)
})?;
Self::new(api_key, model)
}
pub fn bge_small() -> Result<Self, EmbeddingError> {
Self::from_env(HuggingFaceModel::BgeSmall)
}
pub fn bge_base() -> Result<Self, EmbeddingError> {
Self::from_env(HuggingFaceModel::BgeBase)
}
pub fn all_minilm() -> Result<Self, EmbeddingError> {
Self::from_env(HuggingFaceModel::AllMiniLm)
}
}
#[cfg(feature = "huggingface-provider")]
#[async_trait]
impl EmbeddingProvider for HuggingFaceProvider {
fn name(&self) -> &str {
"huggingface"
}
fn model(&self) -> &str {
self.model.model_id()
}
fn dimensions(&self) -> usize {
self.model.dimensions()
}
fn is_local(&self) -> bool {
false
}
fn max_tokens(&self) -> usize {
// HuggingFace doesn't specify a hard limit, but most models handle ~512 tokens
512
}
fn max_batch_size(&self) -> usize {
// HuggingFace Inference API doesn't support batch requests
// Each request is individual
1
}
fn cost_per_1m_tokens(&self) -> f64 {
// HuggingFace Inference API is free for public models
// For dedicated endpoints, costs vary
0.0
}
async fn is_available(&self) -> bool {
// HuggingFace Inference API is always available (with rate limits)
true
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Text cannot be empty".to_string(),
));
}
let url = format!("{}/{}", Self::BASE_URL, self.model.model_id());
let request = HFRequest {
inputs: text.to_string(),
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
EmbeddingError::ApiError(format!("HuggingFace API request failed: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(EmbeddingError::ApiError(format!(
"HuggingFace API error {}: {}",
status, error_text
)));
}
let hf_response: HFResponse = response.json().await.map_err(|e| {
EmbeddingError::ApiError(format!("Failed to parse HuggingFace response: {}", e))
})?;
let mut embedding = match hf_response {
HFResponse::Single(emb) => emb,
HFResponse::Multiple(embs) => {
if embs.is_empty() {
return Err(EmbeddingError::ApiError(
"Empty embeddings response from HuggingFace".to_string(),
));
}
embs[0].clone()
}
};
// Validate dimensions
if embedding.len() != self.dimensions() {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions(),
actual: embedding.len(),
});
}
// Normalize if requested
if options.normalize {
normalize_embedding(&mut embedding);
}
Ok(embedding)
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Texts cannot be empty".to_string(),
));
}
// HuggingFace Inference API doesn't support true batch requests
// We need to send individual requests
let mut embeddings = Vec::with_capacity(texts.len());
for text in texts {
let embedding = self.embed(text, options).await?;
embeddings.push(embedding);
}
Ok(EmbeddingResult {
embeddings,
model: self.model().to_string(),
dimensions: self.dimensions(),
total_tokens: None,
cached_count: 0,
})
}
}
#[cfg(all(test, feature = "huggingface-provider"))]
mod tests {
use super::*;
#[test]
fn test_model_id_mapping() {
assert_eq!(
HuggingFaceModel::BgeSmall.model_id(),
"BAAI/bge-small-en-v1.5"
);
assert_eq!(HuggingFaceModel::BgeSmall.dimensions(), 384);
assert_eq!(
HuggingFaceModel::BgeBase.model_id(),
"BAAI/bge-base-en-v1.5"
);
assert_eq!(HuggingFaceModel::BgeBase.dimensions(), 768);
}
#[test]
fn test_custom_model() {
let custom = HuggingFaceModel::Custom("my-model".to_string(), 512);
assert_eq!(custom.model_id(), "my-model");
assert_eq!(custom.dimensions(), 512);
}
#[test]
fn test_from_model_id() {
let model = HuggingFaceModel::from_model_id("BAAI/bge-small-en-v1.5", None);
assert_eq!(model, HuggingFaceModel::BgeSmall);
let custom = HuggingFaceModel::from_model_id("unknown-model", Some(256));
assert!(matches!(custom, HuggingFaceModel::Custom(_, 256)));
}
#[test]
fn test_provider_creation() {
let provider = HuggingFaceProvider::new("test-key", HuggingFaceModel::BgeSmall);
assert!(provider.is_ok());
let provider = provider.unwrap();
assert_eq!(provider.name(), "huggingface");
assert_eq!(provider.model(), "BAAI/bge-small-en-v1.5");
assert_eq!(provider.dimensions(), 384);
}
#[test]
fn test_empty_api_key() {
let provider = HuggingFaceProvider::new("", HuggingFaceModel::BgeSmall);
assert!(provider.is_err());
assert!(matches!(
provider.unwrap_err(),
EmbeddingError::ConfigError(_)
));
}
}

View File

@ -0,0 +1,25 @@
#[cfg(feature = "cohere-provider")]
pub mod cohere;
#[cfg(feature = "fastembed-provider")]
pub mod fastembed;
#[cfg(feature = "huggingface-provider")]
pub mod huggingface;
#[cfg(feature = "ollama-provider")]
pub mod ollama;
#[cfg(feature = "openai-provider")]
pub mod openai;
#[cfg(feature = "voyage-provider")]
pub mod voyage;
#[cfg(feature = "cohere-provider")]
pub use cohere::{CohereModel, CohereProvider};
#[cfg(feature = "fastembed-provider")]
pub use fastembed::{FastEmbedModel, FastEmbedProvider};
#[cfg(feature = "huggingface-provider")]
pub use huggingface::{HuggingFaceModel, HuggingFaceProvider};
#[cfg(feature = "ollama-provider")]
pub use ollama::{OllamaModel, OllamaProvider};
#[cfg(feature = "openai-provider")]
pub use openai::{OpenAiModel, OpenAiProvider};
#[cfg(feature = "voyage-provider")]
pub use voyage::{VoyageModel, VoyageProvider};

View File

@ -0,0 +1,272 @@
#[cfg(feature = "ollama-provider")]
use std::time::Duration;
#[cfg(feature = "ollama-provider")]
use async_trait::async_trait;
#[cfg(feature = "ollama-provider")]
use reqwest::Client;
#[cfg(feature = "ollama-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "ollama-provider")]
use crate::{
error::EmbeddingError,
traits::{
normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
},
};
#[derive(Debug, Clone, PartialEq, Default)]
pub enum OllamaModel {
#[default]
NomicEmbed,
MxbaiEmbed,
AllMiniLm,
Custom(String, usize),
}
impl OllamaModel {
pub fn model_name(&self) -> &str {
match self {
Self::NomicEmbed => "nomic-embed-text",
Self::MxbaiEmbed => "mxbai-embed-large",
Self::AllMiniLm => "all-minilm",
Self::Custom(name, _) => name,
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::NomicEmbed => 768,
Self::MxbaiEmbed => 1024,
Self::AllMiniLm => 384,
Self::Custom(_, dims) => *dims,
}
}
}
#[derive(Serialize)]
struct OllamaRequest {
model: String,
input: OllamaInput,
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<serde_json::Value>,
}
#[derive(Serialize)]
#[serde(untagged)]
enum OllamaInput {
Single(String),
Batch(Vec<String>),
}
#[derive(Deserialize)]
struct OllamaResponse {
#[serde(default)]
embeddings: Vec<Vec<f32>>,
#[serde(default)]
embedding: Option<Vec<f32>>,
}
pub struct OllamaProvider {
client: Client,
model: OllamaModel,
base_url: String,
}
impl OllamaProvider {
pub fn new(model: OllamaModel) -> Result<Self, EmbeddingError> {
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| EmbeddingError::Initialization(e.to_string()))?;
Ok(Self {
client,
model,
base_url: "http://localhost:11434".to_string(),
})
}
pub fn default_model() -> Result<Self, EmbeddingError> {
Self::new(OllamaModel::default())
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
pub fn custom_model(
name: impl Into<String>,
dimensions: usize,
) -> Result<Self, EmbeddingError> {
Self::new(OllamaModel::Custom(name.into(), dimensions))
}
async fn call_api(&self, input: OllamaInput) -> Result<OllamaResponse, EmbeddingError> {
let request = OllamaRequest {
model: self.model.model_name().to_string(),
input,
options: None,
};
let response = self
.client
.post(format!("{}/api/embed", self.base_url))
.json(&request)
.send()
.await
.map_err(|e| {
if e.is_connect() {
EmbeddingError::ProviderUnavailable(format!(
"Cannot connect to Ollama at {}",
self.base_url
))
} else {
e.into()
}
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(EmbeddingError::ApiError(format!(
"Ollama API error {}: {}",
status, error_text
)));
}
let api_response: OllamaResponse = response.json().await?;
Ok(api_response)
}
}
#[async_trait]
impl EmbeddingProvider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
self.model.model_name()
}
fn dimensions(&self) -> usize {
self.model.dimensions()
}
fn is_local(&self) -> bool {
true
}
fn max_tokens(&self) -> usize {
2048
}
fn max_batch_size(&self) -> usize {
128
}
fn cost_per_1m_tokens(&self) -> f64 {
0.0
}
async fn is_available(&self) -> bool {
self.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Text cannot be empty".to_string(),
));
}
let response = self.call_api(OllamaInput::Single(text.to_string())).await?;
let mut embedding = if let Some(emb) = response.embedding {
emb
} else {
response.embeddings.into_iter().next().ok_or_else(|| {
EmbeddingError::ApiError("Ollama returned no embeddings".to_string())
})?
};
if options.normalize {
normalize_embedding(&mut embedding);
}
Ok(embedding)
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Texts cannot be empty".to_string(),
));
}
if texts.len() > self.max_batch_size() {
return Err(EmbeddingError::BatchSizeExceeded {
size: texts.len(),
max: self.max_batch_size(),
});
}
let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
let response = self.call_api(OllamaInput::Batch(texts_owned)).await?;
let mut embeddings = response.embeddings;
if embeddings.is_empty() {
return Err(EmbeddingError::ApiError(
"Ollama returned no embeddings".to_string(),
));
}
if options.normalize {
embeddings.iter_mut().for_each(normalize_embedding);
}
Ok(EmbeddingResult {
embeddings,
model: self.model().to_string(),
dimensions: self.dimensions(),
total_tokens: None,
cached_count: 0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ollama_model_metadata() {
assert_eq!(OllamaModel::NomicEmbed.dimensions(), 768);
assert_eq!(OllamaModel::MxbaiEmbed.dimensions(), 1024);
assert_eq!(OllamaModel::AllMiniLm.dimensions(), 384);
let custom = OllamaModel::Custom("my-model".to_string(), 512);
assert_eq!(custom.model_name(), "my-model");
assert_eq!(custom.dimensions(), 512);
}
}

View File

@ -0,0 +1,286 @@
#[cfg(feature = "openai-provider")]
use std::time::Duration;
#[cfg(feature = "openai-provider")]
use async_trait::async_trait;
#[cfg(feature = "openai-provider")]
use reqwest::Client;
#[cfg(feature = "openai-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "openai-provider")]
use crate::{
error::EmbeddingError,
traits::{
normalize_embedding, Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult,
},
};
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum OpenAiModel {
#[default]
TextEmbedding3Small,
TextEmbedding3Large,
TextEmbeddingAda002,
}
impl OpenAiModel {
pub fn model_name(&self) -> &'static str {
match self {
Self::TextEmbedding3Small => "text-embedding-3-small",
Self::TextEmbedding3Large => "text-embedding-3-large",
Self::TextEmbeddingAda002 => "text-embedding-ada-002",
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::TextEmbedding3Small => 1536,
Self::TextEmbedding3Large => 3072,
Self::TextEmbeddingAda002 => 1536,
}
}
pub fn cost_per_1m(&self) -> f64 {
match self {
Self::TextEmbedding3Small => 0.02,
Self::TextEmbedding3Large => 0.13,
Self::TextEmbeddingAda002 => 0.10,
}
}
}
#[derive(Serialize)]
struct OpenAiRequest {
input: OpenAiInput,
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<String>,
}
#[derive(Serialize)]
#[serde(untagged)]
enum OpenAiInput {
Single(String),
Batch(Vec<String>),
}
#[derive(Deserialize)]
struct OpenAiResponse {
data: Vec<OpenAiEmbedding>,
usage: OpenAiUsage,
}
#[derive(Deserialize)]
struct OpenAiEmbedding {
embedding: Vec<f32>,
index: usize,
}
#[derive(Deserialize)]
struct OpenAiUsage {
total_tokens: u32,
}
pub struct OpenAiProvider {
client: Client,
api_key: String,
model: OpenAiModel,
base_url: String,
}
impl OpenAiProvider {
pub fn new(api_key: impl Into<String>, model: OpenAiModel) -> Result<Self, EmbeddingError> {
let api_key = api_key.into();
if api_key.is_empty() {
return Err(EmbeddingError::ConfigError(
"OpenAI API key is empty".to_string(),
));
}
let client = Client::builder()
.timeout(Duration::from_secs(60))
.build()
.map_err(|e| EmbeddingError::Initialization(e.to_string()))?;
Ok(Self {
client,
api_key,
model,
base_url: "https://api.openai.com/v1".to_string(),
})
}
pub fn from_env(model: OpenAiModel) -> Result<Self, EmbeddingError> {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| EmbeddingError::ConfigError("OPENAI_API_KEY not set".to_string()))?;
Self::new(api_key, model)
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
async fn call_api(&self, input: OpenAiInput) -> Result<OpenAiResponse, EmbeddingError> {
let request = OpenAiRequest {
input,
model: self.model.model_name().to_string(),
encoding_format: None,
};
let response = self
.client
.post(format!("{}/embeddings", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(EmbeddingError::ApiError(format!(
"OpenAI API error {}: {}",
status, error_text
)));
}
let api_response: OpenAiResponse = response.json().await?;
Ok(api_response)
}
}
#[async_trait]
impl EmbeddingProvider for OpenAiProvider {
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
self.model.model_name()
}
fn dimensions(&self) -> usize {
self.model.dimensions()
}
fn is_local(&self) -> bool {
false
}
fn max_tokens(&self) -> usize {
8191
}
fn max_batch_size(&self) -> usize {
2048
}
fn cost_per_1m_tokens(&self) -> f64 {
self.model.cost_per_1m()
}
async fn is_available(&self) -> bool {
self.client
.get(format!("{}/models", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Text cannot be empty".to_string(),
));
}
let response = self.call_api(OpenAiInput::Single(text.to_string())).await?;
let mut embedding = response
.data
.into_iter()
.next()
.ok_or_else(|| EmbeddingError::ApiError("OpenAI returned no embeddings".to_string()))?
.embedding;
if options.normalize {
normalize_embedding(&mut embedding);
}
Ok(embedding)
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Texts cannot be empty".to_string(),
));
}
if texts.len() > self.max_batch_size() {
return Err(EmbeddingError::BatchSizeExceeded {
size: texts.len(),
max: self.max_batch_size(),
});
}
let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
let response = self.call_api(OpenAiInput::Batch(texts_owned)).await?;
let mut embeddings_with_index: Vec<_> = response
.data
.into_iter()
.map(|e| (e.index, e.embedding))
.collect();
embeddings_with_index.sort_by_key(|(idx, _)| *idx);
let mut embeddings: Vec<Embedding> = embeddings_with_index
.into_iter()
.map(|(_, emb)| emb)
.collect();
if options.normalize {
embeddings.iter_mut().for_each(normalize_embedding);
}
Ok(EmbeddingResult {
embeddings,
model: self.model().to_string(),
dimensions: self.dimensions(),
total_tokens: Some(response.usage.total_tokens),
cached_count: 0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_openai_model_metadata() {
assert_eq!(OpenAiModel::TextEmbedding3Small.dimensions(), 1536);
assert_eq!(OpenAiModel::TextEmbedding3Large.dimensions(), 3072);
assert_eq!(OpenAiModel::TextEmbeddingAda002.dimensions(), 1536);
assert_eq!(OpenAiModel::TextEmbedding3Small.cost_per_1m(), 0.02);
assert_eq!(OpenAiModel::TextEmbedding3Large.cost_per_1m(), 0.13);
}
}

View File

@ -0,0 +1,259 @@
#[cfg(feature = "voyage-provider")]
use async_trait::async_trait;
#[cfg(feature = "voyage-provider")]
use reqwest::Client;
#[cfg(feature = "voyage-provider")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "voyage-provider")]
use crate::{
error::EmbeddingError,
traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult},
};
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum VoyageModel {
#[default]
Voyage2,
VoyageLarge2,
VoyageCode2,
VoyageLite02Instruct,
}
impl VoyageModel {
pub fn model_name(&self) -> &'static str {
match self {
Self::Voyage2 => "voyage-2",
Self::VoyageLarge2 => "voyage-large-2",
Self::VoyageCode2 => "voyage-code-2",
Self::VoyageLite02Instruct => "voyage-lite-02-instruct",
}
}
pub fn dimensions(&self) -> usize {
match self {
Self::Voyage2 => 1024,
Self::VoyageLarge2 => 1536,
Self::VoyageCode2 => 1536,
Self::VoyageLite02Instruct => 1024,
}
}
pub fn max_tokens(&self) -> usize {
match self {
Self::Voyage2 | Self::VoyageLarge2 | Self::VoyageCode2 => 16000,
Self::VoyageLite02Instruct => 4000,
}
}
}
#[derive(Debug, Serialize)]
struct VoyageEmbedRequest {
input: Vec<String>,
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
input_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
truncation: Option<bool>,
}
#[derive(Debug, Deserialize)]
struct VoyageEmbedResponse {
data: Vec<VoyageEmbeddingData>,
usage: VoyageUsage,
}
#[derive(Debug, Deserialize)]
struct VoyageEmbeddingData {
embedding: Vec<f32>,
}
#[derive(Debug, Deserialize)]
struct VoyageUsage {
total_tokens: u32,
}
pub struct VoyageProvider {
client: Client,
api_key: String,
model: VoyageModel,
base_url: String,
}
impl VoyageProvider {
pub fn new(api_key: String, model: VoyageModel) -> Self {
Self {
client: Client::new(),
api_key,
model,
base_url: "https://api.voyageai.com/v1".to_string(),
}
}
pub fn with_base_url(mut self, base_url: String) -> Self {
self.base_url = base_url;
self
}
pub fn voyage_2(api_key: String) -> Self {
Self::new(api_key, VoyageModel::Voyage2)
}
pub fn voyage_large_2(api_key: String) -> Self {
Self::new(api_key, VoyageModel::VoyageLarge2)
}
pub fn voyage_code_2(api_key: String) -> Self {
Self::new(api_key, VoyageModel::VoyageCode2)
}
async fn embed_batch_internal(
&self,
texts: &[&str],
_options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
let request = VoyageEmbedRequest {
input: texts.iter().map(|s| s.to_string()).collect(),
model: self.model.model_name().to_string(),
input_type: Some("document".to_string()),
truncation: Some(true),
};
let response = self
.client
.post(format!("{}/embeddings", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| EmbeddingError::ApiError(format!("Voyage API request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(EmbeddingError::ApiError(format!(
"Voyage API error {}: {}",
status, error_text
)));
}
let result: VoyageEmbedResponse = response.json().await.map_err(|e| {
EmbeddingError::ApiError(format!("Failed to parse Voyage response: {}", e))
})?;
let embeddings: Vec<Embedding> = result.data.into_iter().map(|d| d.embedding).collect();
Ok(EmbeddingResult {
embeddings,
model: self.model.model_name().to_string(),
dimensions: self.model.dimensions(),
total_tokens: Some(result.usage.total_tokens),
cached_count: 0,
})
}
}
#[async_trait]
impl EmbeddingProvider for VoyageProvider {
fn name(&self) -> &str {
"voyage"
}
fn model(&self) -> &str {
self.model.model_name()
}
fn dimensions(&self) -> usize {
self.model.dimensions()
}
fn is_local(&self) -> bool {
false
}
fn max_tokens(&self) -> usize {
self.model.max_tokens()
}
fn max_batch_size(&self) -> usize {
128
}
fn cost_per_1m_tokens(&self) -> f64 {
match self.model {
VoyageModel::Voyage2 => 0.10,
VoyageModel::VoyageLarge2 => 0.12,
VoyageModel::VoyageCode2 => 0.12,
VoyageModel::VoyageLite02Instruct => 0.06,
}
}
async fn is_available(&self) -> bool {
!self.api_key.is_empty()
}
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
let result = self.embed_batch_internal(&[text], options).await?;
result
.embeddings
.into_iter()
.next()
.ok_or_else(|| EmbeddingError::ApiError("No embedding returned".to_string()))
}
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Cannot embed empty text list".to_string(),
));
}
if texts.len() > self.max_batch_size() {
return Err(EmbeddingError::InvalidInput(format!(
"Batch size {} exceeds maximum {}",
texts.len(),
self.max_batch_size()
)));
}
self.embed_batch_internal(texts, options).await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_voyage_model_names() {
assert_eq!(VoyageModel::Voyage2.model_name(), "voyage-2");
assert_eq!(VoyageModel::Voyage2.dimensions(), 1024);
assert_eq!(VoyageModel::VoyageLarge2.dimensions(), 1536);
assert_eq!(VoyageModel::VoyageCode2.dimensions(), 1536);
assert_eq!(VoyageModel::VoyageLite02Instruct.dimensions(), 1024);
}
#[test]
fn test_voyage_max_tokens() {
assert_eq!(VoyageModel::Voyage2.max_tokens(), 16000);
assert_eq!(VoyageModel::VoyageLite02Instruct.max_tokens(), 4000);
}
#[tokio::test]
async fn test_voyage_provider_creation() {
let provider = VoyageProvider::new("test-key".to_string(), VoyageModel::Voyage2);
assert_eq!(provider.name(), "voyage");
assert_eq!(provider.model(), "voyage-2");
assert_eq!(provider.dimensions(), 1024);
assert_eq!(provider.max_batch_size(), 128);
}
}

View File

@ -0,0 +1,283 @@
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::{
batch::BatchProcessor,
cache::EmbeddingCache,
error::EmbeddingError,
traits::{Embedding, EmbeddingOptions, EmbeddingProvider, EmbeddingResult, ProviderInfo},
};
pub struct EmbeddingService<P: EmbeddingProvider, C: EmbeddingCache> {
provider: Arc<P>,
cache: Option<Arc<C>>,
fallback_providers: Vec<Arc<dyn EmbeddingProvider>>,
batch_processor: BatchProcessor<P, C>,
}
impl<P: EmbeddingProvider + 'static, C: EmbeddingCache + 'static> EmbeddingService<P, C> {
pub fn new(provider: P) -> Self {
let provider = Arc::new(provider);
let batch_processor = BatchProcessor::new(Arc::clone(&provider), None);
Self {
provider,
cache: None,
fallback_providers: Vec::new(),
batch_processor,
}
}
pub fn with_cache(mut self, cache: C) -> Self {
let cache = Arc::new(cache);
self.cache = Some(Arc::clone(&cache));
self.batch_processor = BatchProcessor::new(Arc::clone(&self.provider), Some(cache));
self
}
pub fn with_fallback(mut self, fallback: Arc<dyn EmbeddingProvider>) -> Self {
self.fallback_providers.push(fallback);
self
}
pub fn with_batch_concurrency(mut self, max_concurrent: usize) -> Self {
self.batch_processor = self.batch_processor.with_concurrency(max_concurrent);
self
}
pub fn provider_info(&self) -> ProviderInfo {
ProviderInfo::from(self.provider.as_ref())
}
pub async fn is_ready(&self) -> bool {
self.provider.is_available().await
}
pub async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Text cannot be empty".to_string(),
));
}
if options.use_cache {
if let Some(cache) = &self.cache {
let key =
crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
if let Some(cached) = cache.get(&key).await {
debug!("Cache hit for text (len={})", text.len());
return Ok(cached);
}
}
}
let result = self.embed_with_fallback(text, options).await?;
if options.use_cache {
if let Some(cache) = &self.cache {
let key =
crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
cache.insert(&key, result.clone()).await;
}
}
Ok(result)
}
pub async fn embed_batch(
&self,
texts: Vec<String>,
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, EmbeddingError> {
if texts.is_empty() {
return Err(EmbeddingError::InvalidInput(
"Texts cannot be empty".to_string(),
));
}
info!("Processing batch of {} texts", texts.len());
self.batch_processor.process_batch(texts, options).await
}
pub async fn embed_stream(
&self,
texts: Vec<String>,
options: &EmbeddingOptions,
) -> Result<Vec<Embedding>, EmbeddingError> {
self.batch_processor.process_stream(texts, options).await
}
async fn embed_with_fallback(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, EmbeddingError> {
match self.provider.embed(text, options).await {
Ok(embedding) => Ok(embedding),
Err(e) => {
warn!("Primary provider failed: {}", e);
for (idx, fallback) in self.fallback_providers.iter().enumerate() {
debug!("Trying fallback provider {}", idx);
if !fallback.is_available().await {
warn!("Fallback provider {} not available", idx);
continue;
}
match fallback.embed(text, options).await {
Ok(embedding) => {
info!("Fallback provider {} succeeded", idx);
return Ok(embedding);
}
Err(fallback_err) => {
warn!("Fallback provider {} failed: {}", idx, fallback_err);
}
}
}
Err(e)
}
}
}
pub async fn invalidate_cache(&self, text: &str) -> Result<(), EmbeddingError> {
if let Some(cache) = &self.cache {
let key = crate::cache::cache_key(self.provider.name(), self.provider.model(), text);
cache.invalidate(&key).await;
Ok(())
} else {
Err(EmbeddingError::CacheError(
"No cache configured".to_string(),
))
}
}
pub async fn clear_cache(&self) -> Result<(), EmbeddingError> {
if let Some(cache) = &self.cache {
cache.clear().await;
Ok(())
} else {
Err(EmbeddingError::CacheError(
"No cache configured".to_string(),
))
}
}
pub fn cache_size(&self) -> usize {
self.cache.as_ref().map(|c| c.size()).unwrap_or(0)
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::*;
use crate::{cache::MemoryCache, providers::FastEmbedProvider};
#[tokio::test]
async fn test_service_basic() {
let provider = FastEmbedProvider::small().expect("Failed to init");
let service: EmbeddingService<_, MemoryCache> = EmbeddingService::new(provider);
let options = EmbeddingOptions::no_cache();
let embedding = service
.embed("Hello world", &options)
.await
.expect("Failed to embed");
assert_eq!(embedding.len(), 384);
}
#[tokio::test]
async fn test_service_with_cache() {
let provider = FastEmbedProvider::small().expect("Failed to init");
let cache = MemoryCache::new(100, Duration::from_secs(60));
let service = EmbeddingService::new(provider).with_cache(cache);
let options = EmbeddingOptions::default_with_cache();
let embedding1 = service
.embed("Hello world", &options)
.await
.expect("Failed first embed");
assert_eq!(service.cache_size(), 1);
let embedding2 = service
.embed("Hello world", &options)
.await
.expect("Failed second embed");
assert_eq!(embedding1, embedding2);
}
#[tokio::test]
async fn test_service_batch() {
let provider = FastEmbedProvider::small().expect("Failed to init");
let cache = MemoryCache::with_defaults();
let service = EmbeddingService::new(provider).with_cache(cache);
let texts = vec![
"Text 1".to_string(),
"Text 2".to_string(),
"Text 3".to_string(),
];
let options = EmbeddingOptions::default_with_cache();
let result = service
.embed_batch(texts.clone(), &options)
.await
.expect("Failed batch");
assert_eq!(result.embeddings.len(), 3);
assert_eq!(result.cached_count, 0);
let result2 = service
.embed_batch(texts, &options)
.await
.expect("Failed second batch");
assert_eq!(result2.cached_count, 3);
}
#[tokio::test]
async fn test_service_cache_invalidation() {
let provider = FastEmbedProvider::small().expect("Failed to init");
let cache = MemoryCache::with_defaults();
let service = EmbeddingService::new(provider).with_cache(cache);
let options = EmbeddingOptions::default_with_cache();
service
.embed("Test text", &options)
.await
.expect("Failed embed");
assert_eq!(service.cache_size(), 1);
service
.invalidate_cache("Test text")
.await
.expect("Failed to invalidate");
assert_eq!(service.cache_size(), 0);
}
#[tokio::test]
async fn test_service_clear_cache() {
let provider = FastEmbedProvider::small().expect("Failed to init");
let cache = MemoryCache::with_defaults();
let service = EmbeddingService::new(provider).with_cache(cache);
let options = EmbeddingOptions::default_with_cache();
service.embed("Test 1", &options).await.expect("Failed");
service.embed("Test 2", &options).await.expect("Failed");
assert_eq!(service.cache_size(), 2);
service.clear_cache().await.expect("Failed to clear");
assert_eq!(service.cache_size(), 0);
}
}

View File

@ -0,0 +1,430 @@
#[cfg(feature = "lancedb-store")]
use std::sync::Arc;
#[cfg(feature = "lancedb-store")]
use arrow::{
array::{
AsArray, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
},
datatypes::{DataType, Field, Float32Type, Schema},
};
#[cfg(feature = "lancedb-store")]
use async_trait::async_trait;
#[cfg(feature = "lancedb-store")]
use futures::TryStreamExt;
#[cfg(feature = "lancedb-store")]
use lancedb::{query::ExecutableQuery, query::QueryBase, Connection, DistanceType, Table};
#[cfg(feature = "lancedb-store")]
use crate::{
error::EmbeddingError,
store::{SearchFilter, SearchResult, VectorStore, VectorStoreConfig},
traits::Embedding,
};
pub struct LanceDbStore {
#[allow(dead_code)]
connection: Connection,
table: Table,
dimensions: usize,
}
impl LanceDbStore {
pub async fn new(
path: &str,
table_name: &str,
config: VectorStoreConfig,
) -> Result<Self, EmbeddingError> {
let connection = lancedb::connect(path).execute().await.map_err(|e| {
EmbeddingError::Initialization(format!("LanceDB connection failed: {}", e))
})?;
let table = match connection.open_table(table_name).execute().await {
Ok(t) => t,
Err(_) => {
let schema = Self::create_schema(config.dimensions);
let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
let batches = RecordBatchIterator::new(
vec![Ok(empty_batch)].into_iter(),
Arc::clone(&schema),
);
connection
.create_table(table_name, batches)
.execute()
.await
.map_err(|e| {
EmbeddingError::Initialization(format!("Failed to create table: {}", e))
})?
}
};
Ok(Self {
connection,
table,
dimensions: config.dimensions,
})
}
fn create_schema(dimensions: usize) -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dimensions as i32,
),
false,
),
Field::new("metadata", DataType::Utf8, true),
]))
}
fn create_record_batch(
ids: Vec<String>,
embeddings: Vec<Embedding>,
metadata: Vec<String>,
dimensions: usize,
) -> Result<RecordBatch, EmbeddingError> {
let id_array = Arc::new(StringArray::from(ids));
let metadata_array = Arc::new(StringArray::from(metadata));
let flat_values: Vec<f32> = embeddings.into_iter().flatten().collect();
let values_array = Arc::new(Float32Array::from(flat_values));
let vector_array = Arc::new(FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Float32, true)),
dimensions as i32,
values_array,
None,
));
let schema = Self::create_schema(dimensions);
RecordBatch::try_new(schema, vec![id_array, vector_array, metadata_array]).map_err(|e| {
EmbeddingError::StoreError(format!("Failed to create record batch: {}", e))
})
}
}
#[async_trait]
impl VectorStore for LanceDbStore {
fn name(&self) -> &str {
"lancedb"
}
fn dimensions(&self) -> usize {
self.dimensions
}
async fn upsert(
&self,
id: &str,
embedding: &Embedding,
metadata: serde_json::Value,
) -> Result<(), EmbeddingError> {
if embedding.len() != self.dimensions {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
let metadata_str = serde_json::to_string(&metadata)
.map_err(|e| EmbeddingError::SerializationError(e.to_string()))?;
let batch = Self::create_record_batch(
vec![id.to_string()],
vec![embedding.clone()],
vec![metadata_str],
self.dimensions,
)?;
let schema = Self::create_schema(self.dimensions);
let batches = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);
self.table
.add(batches)
.execute()
.await
.map_err(|e| EmbeddingError::StoreError(format!("Failed to upsert: {}", e)))?;
Ok(())
}
async fn upsert_batch(
&self,
items: Vec<(String, Embedding, serde_json::Value)>,
) -> Result<(), EmbeddingError> {
if items.is_empty() {
return Ok(());
}
for (_, embedding, _) in &items {
if embedding.len() != self.dimensions {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
}
let (ids, embeddings, metadata): (Vec<_>, Vec<_>, Vec<_>) = items.into_iter().fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut ids, mut embeddings, mut metadata), (id, embedding, meta)| {
ids.push(id);
embeddings.push(embedding);
metadata.push(serde_json::to_string(&meta).unwrap_or_default());
(ids, embeddings, metadata)
},
);
let batch = Self::create_record_batch(ids, embeddings, metadata, self.dimensions)?;
let schema = Self::create_schema(self.dimensions);
let batches = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema);
self.table
.add(batches)
.execute()
.await
.map_err(|e| EmbeddingError::StoreError(format!("Failed to batch upsert: {}", e)))?;
Ok(())
}
async fn search(
&self,
embedding: &Embedding,
limit: usize,
filter: Option<SearchFilter>,
) -> Result<Vec<SearchResult>, EmbeddingError> {
if embedding.len() != self.dimensions {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
let mut query = self
.table
.vector_search(embedding.as_slice())
.map_err(|e| EmbeddingError::StoreError(format!("Query setup failed: {}", e)))?
.distance_type(DistanceType::Cosine);
if let Some(ref f) = filter {
if f.min_score.is_some() {
query = query.postfilter().refine_factor(10);
}
}
let stream =
query.limit(limit).execute().await.map_err(|e| {
EmbeddingError::StoreError(format!("Query execution failed: {}", e))
})?;
let batches: Vec<RecordBatch> = stream
.try_collect()
.await
.map_err(|e| EmbeddingError::StoreError(format!("Failed to collect results: {}", e)))?;
let mut search_results = Vec::new();
for batch in batches.iter() {
let ids: &Arc<dyn arrow::array::Array> =
batch.column_by_name("id").ok_or_else(|| {
EmbeddingError::StoreError("Missing 'id' column in results".to_string())
})?;
let metadata_col: &Arc<dyn arrow::array::Array> =
batch.column_by_name("metadata").ok_or_else(|| {
EmbeddingError::StoreError("Missing 'metadata' column in results".to_string())
})?;
let distance_col: Option<&Arc<dyn arrow::array::Array>> =
batch.column_by_name("_distance");
let id_array = ids.as_string::<i32>();
let metadata_array = metadata_col.as_string::<i32>();
let num_rows: usize = batch.num_rows();
for i in 0..num_rows {
let id = id_array.value(i).to_string();
let metadata_str = metadata_array.value(i);
let metadata: serde_json::Value =
serde_json::from_str(metadata_str).unwrap_or(serde_json::json!({}));
let score = if let Some(dist_col) = distance_col {
let dist_array = dist_col.as_primitive::<Float32Type>();
1.0 - dist_array.value(i)
} else {
0.0
};
search_results.push(SearchResult {
id,
score,
embedding: None,
metadata,
});
}
}
if let Some(f) = filter {
if let Some(min_score) = f.min_score {
search_results.retain(|r| r.score >= min_score);
}
}
Ok(search_results)
}
async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError> {
let stream = self
.table
.query()
.only_if(format!("id = '{}'", id))
.limit(1)
.execute()
.await
.map_err(|e| EmbeddingError::StoreError(format!("Query execution failed: {}", e)))?;
let results: Vec<RecordBatch> = stream
.try_collect()
.await
.map_err(|e| EmbeddingError::StoreError(format!("Failed to collect results: {}", e)))?;
for batch in results.iter() {
let num_rows: usize = batch.num_rows();
if num_rows == 0 {
continue;
}
let ids: &Arc<dyn arrow::array::Array> =
batch.column_by_name("id").ok_or_else(|| {
EmbeddingError::StoreError("Missing 'id' column in results".to_string())
})?;
let metadata_col: &Arc<dyn arrow::array::Array> =
batch.column_by_name("metadata").ok_or_else(|| {
EmbeddingError::StoreError("Missing 'metadata' column in results".to_string())
})?;
let id_array = ids.as_string::<i32>();
let metadata_array = metadata_col.as_string::<i32>();
let result_id = id_array.value(0).to_string();
let metadata_str = metadata_array.value(0);
let metadata: serde_json::Value =
serde_json::from_str(metadata_str).unwrap_or(serde_json::json!({}));
return Ok(Some(SearchResult {
id: result_id,
score: 1.0,
embedding: None,
metadata,
}));
}
Ok(None)
}
async fn delete(&self, id: &str) -> Result<bool, EmbeddingError> {
self.table
.delete(&format!("id = '{}'", id))
.await
.map_err(|e| EmbeddingError::StoreError(format!("Delete failed: {}", e)))?;
Ok(true)
}
async fn flush(&self) -> Result<(), EmbeddingError> {
Ok(())
}
async fn count(&self) -> Result<usize, EmbeddingError> {
let count = self
.table
.count_rows(None)
.await
.map_err(|e| EmbeddingError::StoreError(format!("Count failed: {}", e)))?;
Ok(count)
}
}
#[cfg(test)]
mod tests {
use tempfile::tempdir;
use super::*;
#[tokio::test]
async fn test_lancedb_store_basic() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().to_str().unwrap();
let config = VectorStoreConfig::new(384);
let store = LanceDbStore::new(path, "test_embeddings", config)
.await
.expect("Failed to create store");
let embedding = vec![0.1; 384];
let metadata = serde_json::json!({"text": "hello world"});
store
.upsert("test_id", &embedding, metadata.clone())
.await
.expect("Failed to upsert");
let result = store.get("test_id").await.expect("Failed to get");
assert!(result.is_some());
let search_results = store
.search(&embedding, 5, None)
.await
.expect("Failed to search");
assert!(!search_results.is_empty());
let count = store.count().await.expect("Failed to count");
assert_eq!(count, 1);
}
#[tokio::test]
async fn test_lancedb_store_batch() {
let dir = tempdir().expect("Failed to create temp dir");
let path = dir.path().to_str().unwrap();
let config = VectorStoreConfig::new(384);
let store = LanceDbStore::new(path, "test_batch", config)
.await
.expect("Failed to create store");
let items = vec![
(
"id1".to_string(),
vec![0.1; 384],
serde_json::json!({"idx": 1}),
),
(
"id2".to_string(),
vec![0.2; 384],
serde_json::json!({"idx": 2}),
),
(
"id3".to_string(),
vec![0.3; 384],
serde_json::json!({"idx": 3}),
),
];
store
.upsert_batch(items)
.await
.expect("Failed to batch upsert");
let count = store.count().await.expect("Failed to count");
assert_eq!(count, 3);
}
}

View File

@ -0,0 +1,14 @@
pub mod traits;
#[cfg(feature = "lancedb-store")]
pub mod lancedb;
#[cfg(feature = "surrealdb-store")]
pub mod surrealdb;
pub use traits::*;
#[cfg(feature = "lancedb-store")]
pub use self::lancedb::LanceDbStore;
#[cfg(feature = "surrealdb-store")]
pub use self::surrealdb::SurrealDbStore;

View File

@ -0,0 +1,298 @@
#[cfg(feature = "surrealdb-store")]
use async_trait::async_trait;
#[cfg(feature = "surrealdb-store")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "surrealdb-store")]
use surrealdb::{
engine::local::{Db, Mem},
sql::Thing,
Surreal,
};
#[cfg(feature = "surrealdb-store")]
use crate::{
error::EmbeddingError,
store::{SearchFilter, SearchResult, VectorStore, VectorStoreConfig},
traits::Embedding,
};
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingRecord {
id: Option<Thing>,
vector: Vec<f32>,
metadata: serde_json::Value,
}
pub struct SurrealDbStore {
db: Surreal<Db>,
table: String,
dimensions: usize,
}
impl SurrealDbStore {
pub async fn new(
connection: Surreal<Db>,
table_name: &str,
config: VectorStoreConfig,
) -> Result<Self, EmbeddingError> {
Ok(Self {
db: connection,
table: table_name.to_string(),
dimensions: config.dimensions,
})
}
pub async fn new_memory(
table_name: &str,
config: VectorStoreConfig,
) -> Result<Self, EmbeddingError> {
let db = Surreal::new::<Mem>(()).await.map_err(|e| {
EmbeddingError::Initialization(format!("SurrealDB connection failed: {}", e))
})?;
db.use_ns("embeddings")
.use_db("embeddings")
.await
.map_err(|e| {
EmbeddingError::Initialization(format!("Failed to set namespace: {}", e))
})?;
Self::new(db, table_name, config).await
}
fn compute_cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag_a > 0.0 && mag_b > 0.0 {
dot / (mag_a * mag_b)
} else {
0.0
}
}
}
#[async_trait]
impl VectorStore for SurrealDbStore {
fn name(&self) -> &str {
"surrealdb"
}
fn dimensions(&self) -> usize {
self.dimensions
}
async fn upsert(
&self,
id: &str,
embedding: &Embedding,
metadata: serde_json::Value,
) -> Result<(), EmbeddingError> {
if embedding.len() != self.dimensions {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
let record = EmbeddingRecord {
id: None,
vector: embedding.clone(),
metadata,
};
let _: Option<EmbeddingRecord> = self
.db
.update((self.table.as_str(), id))
.content(record)
.await
.map_err(|e| EmbeddingError::StoreError(format!("Upsert failed: {}", e)))?;
Ok(())
}
async fn search(
&self,
embedding: &Embedding,
limit: usize,
filter: Option<SearchFilter>,
) -> Result<Vec<SearchResult>, EmbeddingError> {
if embedding.len() != self.dimensions {
return Err(EmbeddingError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
let all_records: Vec<EmbeddingRecord> = self
.db
.select(&self.table)
.await
.map_err(|e| EmbeddingError::StoreError(format!("Search failed: {}", e)))?;
let mut scored_results: Vec<(String, f32, serde_json::Value)> = all_records
.into_iter()
.filter_map(|record| {
let id = record.id?.id.to_string();
let score = self.compute_cosine_similarity(embedding, &record.vector);
Some((id, score, record.metadata))
})
.collect();
scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
if let Some(f) = filter {
if let Some(min_score) = f.min_score {
scored_results.retain(|(_, score, _)| *score >= min_score);
}
}
let results = scored_results
.into_iter()
.take(limit)
.map(|(id, score, metadata)| SearchResult {
id,
score,
embedding: None,
metadata,
})
.collect();
Ok(results)
}
async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError> {
let record: Option<EmbeddingRecord> = self
.db
.select((self.table.as_str(), id))
.await
.map_err(|e| EmbeddingError::StoreError(format!("Get failed: {}", e)))?;
Ok(record.map(|r| SearchResult {
id: id.to_string(),
score: 1.0,
embedding: Some(r.vector),
metadata: r.metadata,
}))
}
async fn delete(&self, id: &str) -> Result<bool, EmbeddingError> {
let result: Option<EmbeddingRecord> = self
.db
.delete((self.table.as_str(), id))
.await
.map_err(|e| EmbeddingError::StoreError(format!("Delete failed: {}", e)))?;
Ok(result.is_some())
}
async fn flush(&self) -> Result<(), EmbeddingError> {
Ok(())
}
async fn count(&self) -> Result<usize, EmbeddingError> {
let records: Vec<EmbeddingRecord> = self
.db
.select(&self.table)
.await
.map_err(|e| EmbeddingError::StoreError(format!("Count failed: {}", e)))?;
Ok(records.len())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_surrealdb_store_basic() {
let config = VectorStoreConfig::new(384);
let store = SurrealDbStore::new_memory("test_embeddings", config)
.await
.expect("Failed to create store");
let embedding = vec![0.1; 384];
let metadata = serde_json::json!({"text": "hello world"});
store
.upsert("test_id", &embedding, metadata.clone())
.await
.expect("Failed to upsert");
let result = store.get("test_id").await.expect("Failed to get");
assert!(result.is_some());
assert_eq!(result.as_ref().unwrap().id, "test_id");
assert_eq!(result.as_ref().unwrap().metadata, metadata);
let search_results = store
.search(&embedding, 5, None)
.await
.expect("Failed to search");
assert_eq!(search_results.len(), 1);
assert!(search_results[0].score > 0.99);
let count = store.count().await.expect("Failed to count");
assert_eq!(count, 1);
}
#[tokio::test]
async fn test_surrealdb_store_search() {
let config = VectorStoreConfig::new(3);
let store = SurrealDbStore::new_memory("test_search", config)
.await
.expect("Failed to create store");
let embeddings = [
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.5, 0.5, 0.0],
];
for (i, embedding) in embeddings.iter().enumerate() {
store
.upsert(
&format!("id_{}", i),
embedding,
serde_json::json!({"idx": i}),
)
.await
.expect("Failed to upsert");
}
let query = vec![1.0, 0.0, 0.0];
let results = store
.search(&query, 2, None)
.await
.expect("Failed to search");
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, "id_0");
assert!(results[0].score > 0.99);
}
#[tokio::test]
async fn test_surrealdb_store_delete() {
let config = VectorStoreConfig::new(384);
let store = SurrealDbStore::new_memory("test_delete", config)
.await
.expect("Failed to create store");
let embedding = vec![0.1; 384];
store
.upsert("test_id", &embedding, serde_json::json!({}))
.await
.expect("Failed to upsert");
let deleted = store.delete("test_id").await.expect("Failed to delete");
assert!(deleted);
let result = store.get("test_id").await.expect("Failed to get");
assert!(result.is_none());
}
}

View File

@ -0,0 +1,102 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{error::EmbeddingError, traits::Embedding};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub id: String,
pub score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub embedding: Option<Embedding>,
pub metadata: serde_json::Value,
}
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
pub metadata: Option<serde_json::Value>,
pub min_score: Option<f32>,
}
#[derive(Debug, Clone)]
pub struct VectorStoreConfig {
pub dimensions: usize,
pub metric: DistanceMetric,
pub options: serde_json::Value,
}
impl VectorStoreConfig {
pub fn new(dimensions: usize) -> Self {
Self {
dimensions,
metric: DistanceMetric::default(),
options: serde_json::json!({}),
}
}
pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
self.metric = metric;
self
}
pub fn with_options(mut self, options: serde_json::Value) -> Self {
self.options = options;
self
}
}
#[derive(Debug, Clone, Copy, Default)]
pub enum DistanceMetric {
#[default]
Cosine,
Euclidean,
DotProduct,
}
#[async_trait]
pub trait VectorStore: Send + Sync {
fn name(&self) -> &str;
fn dimensions(&self) -> usize;
async fn upsert(
&self,
id: &str,
embedding: &Embedding,
metadata: serde_json::Value,
) -> Result<(), EmbeddingError>;
async fn upsert_batch(
&self,
items: Vec<(String, Embedding, serde_json::Value)>,
) -> Result<(), EmbeddingError> {
for (id, embedding, metadata) in items {
self.upsert(&id, &embedding, metadata).await?;
}
Ok(())
}
async fn search(
&self,
embedding: &Embedding,
limit: usize,
filter: Option<SearchFilter>,
) -> Result<Vec<SearchResult>, EmbeddingError>;
async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError>;
async fn delete(&self, id: &str) -> Result<bool, EmbeddingError>;
async fn delete_batch(&self, ids: &[&str]) -> Result<usize, EmbeddingError> {
let mut count = 0;
for id in ids {
if self.delete(id).await? {
count += 1;
}
}
Ok(count)
}
async fn flush(&self) -> Result<(), EmbeddingError>;
async fn count(&self) -> Result<usize, EmbeddingError>;
}

View File

@ -0,0 +1,162 @@
use async_trait::async_trait;
pub type Embedding = Vec<f32>;
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
pub embeddings: Vec<Embedding>,
pub model: String,
pub dimensions: usize,
pub total_tokens: Option<u32>,
pub cached_count: usize,
}
#[derive(Debug, Clone, Default)]
pub struct EmbeddingOptions {
pub normalize: bool,
pub truncate: bool,
pub use_cache: bool,
}
impl EmbeddingOptions {
pub fn default_with_cache() -> Self {
Self {
normalize: true,
truncate: true,
use_cache: true,
}
}
pub fn no_cache() -> Self {
Self {
normalize: true,
truncate: true,
use_cache: false,
}
}
}
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
fn name(&self) -> &str;
fn model(&self) -> &str;
fn dimensions(&self) -> usize;
fn is_local(&self) -> bool;
fn max_tokens(&self) -> usize;
fn max_batch_size(&self) -> usize;
fn cost_per_1m_tokens(&self) -> f64;
async fn is_available(&self) -> bool;
async fn embed(
&self,
text: &str,
options: &EmbeddingOptions,
) -> Result<Embedding, crate::error::EmbeddingError>;
async fn embed_batch(
&self,
texts: &[&str],
options: &EmbeddingOptions,
) -> Result<EmbeddingResult, crate::error::EmbeddingError>;
}
#[derive(Debug, Clone)]
pub struct ProviderInfo {
pub name: String,
pub model: String,
pub dimensions: usize,
pub is_local: bool,
pub cost_per_1m: f64,
pub max_batch_size: usize,
}
impl<T: EmbeddingProvider> From<&T> for ProviderInfo {
fn from(provider: &T) -> Self {
Self {
name: provider.name().to_string(),
model: provider.model().to_string(),
dimensions: provider.dimensions(),
is_local: provider.is_local(),
cost_per_1m: provider.cost_per_1m_tokens(),
max_batch_size: provider.max_batch_size(),
}
}
}
pub fn normalize_embedding(embedding: &mut Embedding) {
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude > 0.0 {
embedding.iter_mut().for_each(|x| *x /= magnitude);
}
}
pub fn cosine_similarity(a: &Embedding, b: &Embedding) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude_a > 0.0 && magnitude_b > 0.0 {
dot_product / (magnitude_a * magnitude_b)
} else {
0.0
}
}
pub fn euclidean_distance(a: &Embedding, b: &Embedding) -> f32 {
if a.len() != b.len() {
return f32::MAX;
}
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
#[cfg(test)]
mod tests {
use approx::assert_relative_eq;
use super::*;
#[test]
fn test_normalize_embedding() {
let mut embedding = vec![3.0, 4.0];
normalize_embedding(&mut embedding);
assert_relative_eq!(embedding[0], 0.6, epsilon = 0.0001);
assert_relative_eq!(embedding[1], 0.8, epsilon = 0.0001);
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert_relative_eq!(magnitude, 1.0, epsilon = 0.0001);
}
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert_relative_eq!(cosine_similarity(&a, &b), 1.0, epsilon = 0.0001);
let c = vec![0.0, 1.0, 0.0];
assert_relative_eq!(cosine_similarity(&a, &c), 0.0, epsilon = 0.0001);
let d = vec![-1.0, 0.0, 0.0];
assert_relative_eq!(cosine_similarity(&a, &d), -1.0, epsilon = 0.0001);
}
#[test]
fn test_euclidean_distance() {
let a = vec![0.0, 0.0];
let b = vec![3.0, 4.0];
assert_relative_eq!(euclidean_distance(&a, &b), 5.0, epsilon = 0.0001);
let c = vec![1.0, 1.0];
let d = vec![1.0, 1.0];
assert_relative_eq!(euclidean_distance(&c, &d), 0.0, epsilon = 0.0001);
}
}

View File

@ -0,0 +1,65 @@
[package]
name = "stratum-llm"
version = "0.1.0"
edition.workspace = true
description = "Unified LLM abstraction with CLI detection, fallback, and caching"
license.workspace = true
[dependencies]
# Async runtime
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
# HTTP client
reqwest = { workspace = true, features = ["stream"] }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true, optional = true }
# Caching
moka = { workspace = true }
# Error handling
thiserror = { workspace = true }
# Logging
tracing = { workspace = true }
# Metrics
prometheus = { workspace = true, optional = true }
# Utilities
dirs = { workspace = true }
chrono = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
which = { workspace = true, optional = true }
# Hashing for cache keys
xxhash-rust = { workspace = true }
[features]
default = ["anthropic", "openai", "ollama"]
anthropic = []
openai = []
deepseek = []
ollama = []
claude-cli = []
openai-cli = []
kogral = ["serde_yaml", "which"]
metrics = ["prometheus"]
all = [
"anthropic",
"openai",
"deepseek",
"ollama",
"claude-cli",
"openai-cli",
"kogral",
"metrics",
]
[dev-dependencies]
tokio-test = { workspace = true }

View File

@ -0,0 +1,131 @@
# stratum-llm
Unified LLM abstraction for the stratumiops ecosystem with automatic provider detection, fallback chains, and smart caching.
## Features
- **Credential Auto-detection**: Automatically finds CLI credentials (Claude, OpenAI) and API keys
- **Provider Fallback**: Circuit breaker pattern with automatic failover across providers
- **Smart Caching**: xxHash-based request deduplication reduces duplicate API calls
- **Kogral Integration**: Inject project context from knowledge base (optional)
- **Cost Tracking**: Transparent cost estimation across all providers
- **Multiple Providers**: Anthropic Claude, OpenAI, DeepSeek, Ollama
## Quick Start
```rust
use stratum_llm::{UnifiedClient, Message, Role};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = UnifiedClient::auto()?;
let messages = vec![
Message {
role: Role::User,
content: "What is Rust?".to_string(),
}
];
let response = client.generate(&messages, None).await?;
println!("{}", response.content);
Ok(())
}
```
## Provider Priority
1. **CLI credentials** (subscription-based, no per-token cost) - preferred
2. **API keys** from environment variables
3. **Local models** (Ollama)
The client automatically detects available credentials and builds a fallback chain.
## Features
### Default Features
```toml
[dependencies]
stratum-llm = "0.1"
```
Includes: Anthropic, OpenAI, Ollama
### All Features
```toml
[dependencies]
stratum-llm = { version = "0.1", features = ["all"] }
```
Includes: All providers, CLI detection, Kogral integration, Prometheus metrics
### Custom Feature Set
```toml
[dependencies]
stratum-llm = { version = "0.1", features = ["anthropic", "deepseek", "kogral"] }
```
Available features:
- `anthropic` - Anthropic Claude API
- `openai` - OpenAI API
- `deepseek` - DeepSeek API
- `ollama` - Ollama local models
- `claude-cli` - Claude CLI credential detection
- `kogral` - Kogral knowledge base integration
- `metrics` - Prometheus metrics
## Advanced Usage
### With Kogral Context
```rust
let client = UnifiedClient::builder()
.auto_detect()?
.with_kogral()
.build()?;
let response = client
.generate_with_kogral(&messages, None, Some("rust"), None)
.await?;
```
### Custom Fallback Strategy
```rust
use stratum_llm::{FallbackStrategy, ProviderChain};
let chain = ProviderChain::from_detected()?
.with_strategy(FallbackStrategy::OnRateLimitOrUnavailable);
let client = UnifiedClient::builder()
.with_chain(chain)
.build()?;
```
### Cost Budget
```rust
let chain = ProviderChain::from_detected()?
.with_strategy(FallbackStrategy::OnBudgetExceeded {
budget_cents: 10.0,
});
```
## Examples
Run examples with:
```bash
cargo run --example basic_usage
cargo run --example with_kogral --features kogral
cargo run --example fallback_demo
```
## License
MIT OR Apache-2.0

View File

@ -0,0 +1,43 @@
use stratum_llm::{Message, Role, UnifiedClient};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
println!("Creating UnifiedClient with auto-detected providers...");
let client = UnifiedClient::auto()?;
println!("\nAvailable providers:");
for provider in client.providers() {
println!(
" - {} ({}): circuit={:?}, subscription={}",
provider.name, provider.model, provider.circuit_state, provider.is_subscription
);
}
let messages = vec![Message {
role: Role::User,
content: "What is the capital of France? Answer in one word.".to_string(),
}];
println!("\nSending request...");
match client.generate(&messages, None).await {
Ok(response) => {
println!("\n✓ Success!");
println!("Provider: {}", response.provider);
println!("Model: {}", response.model);
println!("Response: {}", response.content);
println!(
"Tokens: {} in, {} out",
response.input_tokens, response.output_tokens
);
println!("Cost: ${:.4}", response.cost_cents / 100.0);
println!("Latency: {}ms", response.latency_ms);
}
Err(e) => {
eprintln!("\n✗ Error: {}", e);
}
}
Ok(())
}

View File

@ -0,0 +1,54 @@
use stratum_llm::{FallbackStrategy, Message, Role, UnifiedClient};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
println!("Creating UnifiedClient with budget-based fallback strategy...");
let client = UnifiedClient::builder().auto_detect()?.build()?;
let messages = vec![Message {
role: Role::User,
content: "Explain quantum computing in simple terms.".to_string(),
}];
println!("\nProvider chain:");
for (idx, provider) in client.providers().iter().enumerate() {
println!(
" {}. {} ({}) - circuit: {:?}",
idx + 1,
provider.name,
provider.model,
provider.circuit_state
);
}
println!("\nSending multiple requests to test fallback...");
for i in 1..=3 {
println!("\n--- Request {} ---", i);
match client.generate(&messages, None).await {
Ok(response) => {
println!(
"✓ Provider: {} | Model: {} | Cost: ${:.4} | Latency: {}ms",
response.provider,
response.model,
response.cost_cents / 100.0,
response.latency_ms
);
println!(
"Response preview: {}...",
&response.content[..100.min(response.content.len())]
);
}
Err(e) => {
eprintln!("✗ All providers failed: {}", e);
}
}
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
}
Ok(())
}

View File

@ -0,0 +1,45 @@
#[cfg(feature = "kogral")]
use stratum_llm::{Message, Role, UnifiedClient};
#[cfg(feature = "kogral")]
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
println!("Creating UnifiedClient with Kogral integration...");
let client = UnifiedClient::builder()
.auto_detect()?
.with_kogral()
.build()?;
let messages = vec![Message {
role: Role::User,
content: "Write a simple Rust function to add two numbers.".to_string(),
}];
println!("\nSending request with Rust guidelines from Kogral...");
match client
.generate_with_kogral(&messages, None, Some("rust"), None)
.await
{
Ok(response) => {
println!("\n✓ Success!");
println!("Provider: {}", response.provider);
println!("Model: {}", response.model);
println!("Response:\n{}", response.content);
println!("\nCost: ${:.4}", response.cost_cents / 100.0);
println!("Latency: {}ms", response.latency_ms);
}
Err(e) => {
eprintln!("\n✗ Error: {}", e);
}
}
Ok(())
}
#[cfg(not(feature = "kogral"))]
fn main() {
eprintln!("This example requires the 'kogral' feature.");
eprintln!("Run with: cargo run --example with_kogral --features kogral");
}

3
crates/stratum-llm/src/cache/mod.rs vendored Normal file
View File

@ -0,0 +1,3 @@
pub mod request_cache;
pub use request_cache::{CacheConfig, CacheStats, CachedResponse, RequestCache};

View File

@ -0,0 +1,151 @@
use std::time::Duration;
use moka::future::Cache;
use xxhash_rust::xxh3::xxh3_64;
#[derive(Clone)]
pub struct CachedResponse {
pub content: String,
pub model: String,
pub provider: String,
pub cached_at: chrono::DateTime<chrono::Utc>,
}
pub struct RequestCache {
cache: Cache<u64, CachedResponse>,
enabled: bool,
}
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub enabled: bool,
pub max_entries: u64,
pub ttl: Duration,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
max_entries: 1000,
ttl: Duration::from_secs(3600),
}
}
}
impl RequestCache {
pub fn new(config: CacheConfig) -> Self {
let cache = Cache::builder()
.max_capacity(config.max_entries)
.time_to_live(config.ttl)
.build();
Self {
cache,
enabled: config.enabled,
}
}
fn compute_key(
&self,
messages: &[crate::providers::Message],
options: &crate::providers::GenerationOptions,
) -> u64 {
let mut hasher_input = String::new();
for msg in messages {
hasher_input.push_str(&format!("{:?}:{}\\n", msg.role, msg.content));
}
hasher_input.push_str(&format!(
"temp:{:?}|max:{:?}|top_p:{:?}",
options.temperature, options.max_tokens, options.top_p,
));
xxh3_64(hasher_input.as_bytes())
}
pub async fn get(
&self,
messages: &[crate::providers::Message],
options: &crate::providers::GenerationOptions,
) -> Option<CachedResponse> {
if !self.enabled {
return None;
}
let key = self.compute_key(messages, options);
self.cache.get(&key).await
}
pub async fn put(
&self,
messages: &[crate::providers::Message],
options: &crate::providers::GenerationOptions,
response: &crate::providers::GenerationResponse,
) {
if !self.enabled {
return;
}
let key = self.compute_key(messages, options);
let cached = CachedResponse {
content: response.content.clone(),
model: response.model.clone(),
provider: response.provider.clone(),
cached_at: chrono::Utc::now(),
};
self.cache.insert(key, cached).await;
}
pub async fn get_or_generate<F, Fut>(
&self,
messages: &[crate::providers::Message],
options: &crate::providers::GenerationOptions,
generate: F,
) -> Result<crate::providers::GenerationResponse, crate::error::LlmError>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<
Output = Result<crate::providers::GenerationResponse, crate::error::LlmError>,
>,
{
if let Some(cached) = self.get(messages, options).await {
tracing::debug!("Cache hit");
return Ok(crate::providers::GenerationResponse {
content: cached.content,
model: cached.model,
provider: cached.provider,
input_tokens: 0,
output_tokens: 0,
cost_cents: 0.0,
latency_ms: 0,
});
}
let response = generate().await?;
self.put(messages, options, &response).await;
Ok(response)
}
pub fn stats(&self) -> CacheStats {
CacheStats {
entry_count: self.cache.entry_count(),
hit_count: 0,
miss_count: 0,
}
}
pub fn clear(&self) {
self.cache.invalidate_all();
}
}
#[derive(Debug)]
pub struct CacheStats {
pub entry_count: u64,
pub hit_count: u64,
pub miss_count: u64,
}

View File

@ -0,0 +1,146 @@
use std::sync::atomic::{AtomicU32, AtomicU8, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
Closed = 0,
Open = 1,
HalfOpen = 2,
}
pub struct CircuitBreaker {
state: AtomicU8,
failure_count: AtomicU32,
success_count: AtomicU32,
config: CircuitBreakerConfig,
last_failure_time: RwLock<Option<Instant>>,
last_success_time: RwLock<Option<Instant>>,
}
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
pub failure_threshold: u32,
pub success_threshold: u32,
pub reset_timeout: Duration,
pub request_timeout: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 3,
reset_timeout: Duration::from_secs(30),
request_timeout: Duration::from_secs(60),
}
}
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: AtomicU8::new(CircuitState::Closed as u8),
failure_count: AtomicU32::new(0),
success_count: AtomicU32::new(0),
config,
last_failure_time: RwLock::new(None),
last_success_time: RwLock::new(None),
}
}
pub fn state(&self) -> CircuitState {
match self.state.load(Ordering::SeqCst) {
0 => CircuitState::Closed,
1 => CircuitState::Open,
2 => CircuitState::HalfOpen,
_ => CircuitState::Closed,
}
}
pub fn should_allow(&self) -> bool {
match self.state() {
CircuitState::Closed => true,
CircuitState::Open => {
if let Some(last_failure) = *self.last_failure_time.read().unwrap() {
if last_failure.elapsed() >= self.config.reset_timeout {
self.state
.store(CircuitState::HalfOpen as u8, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
return true;
}
}
false
}
CircuitState::HalfOpen => true,
}
}
pub fn record_success(&self) {
*self.last_success_time.write().unwrap() = Some(Instant::now());
self.failure_count.store(0, Ordering::SeqCst);
if self.state() == CircuitState::HalfOpen {
let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.config.success_threshold {
self.state
.store(CircuitState::Closed as u8, Ordering::SeqCst);
tracing::info!("Circuit breaker closed (recovered)");
}
}
}
pub fn record_failure(&self) {
*self.last_failure_time.write().unwrap() = Some(Instant::now());
match self.state() {
CircuitState::Closed => {
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.config.failure_threshold {
self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
tracing::warn!(
failures = count,
"Circuit breaker opened (too many failures)"
);
}
}
CircuitState::HalfOpen => {
self.state.store(CircuitState::Open as u8, Ordering::SeqCst);
self.success_count.store(0, Ordering::SeqCst);
tracing::warn!("Circuit breaker reopened (half-open test failed)");
}
CircuitState::Open => {}
}
}
pub async fn call<F, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
where
F: std::future::Future<Output = Result<T, E>>,
{
if !self.should_allow() {
return Err(CircuitError::Open);
}
match tokio::time::timeout(self.config.request_timeout, f).await {
Ok(Ok(result)) => {
self.record_success();
Ok(result)
}
Ok(Err(e)) => {
self.record_failure();
Err(CircuitError::Inner(e))
}
Err(_) => {
self.record_failure();
Err(CircuitError::Timeout)
}
}
}
}
#[derive(Debug)]
pub enum CircuitError<E> {
Open,
Timeout,
Inner(E),
}

View File

@ -0,0 +1,5 @@
pub mod circuit_breaker;
pub mod provider_chain;
pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitError, CircuitState};
pub use provider_chain::{FallbackStrategy, ProviderChain, ProviderInfo};

View File

@ -0,0 +1,208 @@
use crate::chain::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitError};
use crate::credentials::CredentialDetector;
use crate::error::LlmError;
use crate::providers::{
ConfiguredProvider, GenerationOptions, GenerationResponse, LlmProvider, Message,
};
pub struct ProviderChain {
providers: Vec<ProviderWithCircuit>,
strategy: FallbackStrategy,
}
struct ProviderWithCircuit {
provider: Box<dyn LlmProvider>,
circuit: CircuitBreaker,
is_subscription: bool,
priority: u32,
}
#[derive(Clone)]
pub enum FallbackStrategy {
Sequential,
OnRateLimitOrUnavailable,
OnBudgetExceeded { budget_cents: f64 },
}
impl ProviderChain {
pub fn from_detected() -> Result<Self, LlmError> {
let detector = CredentialDetector::new();
let credentials = detector.detect_all();
if credentials.is_empty() {
return Err(LlmError::NoProvidersAvailable);
}
let mut providers = Vec::new();
for cred in credentials {
let provider: Option<Box<dyn LlmProvider>> = match cred.provider.as_str() {
#[cfg(feature = "anthropic")]
"anthropic" => Some(Box::new(
crate::providers::AnthropicProvider::sonnet()
.map_err(|_| LlmError::NoProvidersAvailable)?,
)),
#[cfg(feature = "openai")]
"openai" => Some(Box::new(
crate::providers::OpenAiProvider::gpt4o()
.map_err(|_| LlmError::NoProvidersAvailable)?,
)),
#[cfg(feature = "deepseek")]
"deepseek" => Some(Box::new(
crate::providers::DeepSeekProvider::coder()
.map_err(|_| LlmError::NoProvidersAvailable)?,
)),
#[cfg(feature = "ollama")]
"ollama" => Some(Box::new(crate::providers::OllamaProvider::default())),
_ => None,
};
if let Some(provider) = provider {
providers.push(ProviderWithCircuit {
provider,
circuit: CircuitBreaker::new(CircuitBreakerConfig::default()),
is_subscription: cred.is_subscription,
priority: if cred.is_subscription { 0 } else { 10 },
});
}
}
if providers.is_empty() {
return Err(LlmError::NoProvidersAvailable);
}
providers.sort_by_key(|p| p.priority);
Ok(Self {
providers,
strategy: FallbackStrategy::Sequential,
})
}
pub fn with_providers(providers: Vec<ConfiguredProvider>) -> Self {
let providers = providers
.into_iter()
.map(|p| ProviderWithCircuit {
provider: p.provider,
circuit: CircuitBreaker::new(CircuitBreakerConfig::default()),
is_subscription: matches!(
p.credential_source,
crate::providers::CredentialSource::Cli { .. }
),
priority: p.priority,
})
.collect();
Self {
providers,
strategy: FallbackStrategy::Sequential,
}
}
pub fn with_strategy(mut self, strategy: FallbackStrategy) -> Self {
self.strategy = strategy;
self
}
pub async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, LlmError> {
let mut last_error: Option<LlmError> = None;
for pwc in &self.providers {
if !pwc.circuit.should_allow() {
tracing::debug!(provider = pwc.provider.name(), "Circuit open, skipping");
continue;
}
if let FallbackStrategy::OnBudgetExceeded { budget_cents } = &self.strategy {
let estimated_cost = pwc.provider.estimate_cost(
estimate_tokens(messages),
options.max_tokens.unwrap_or(1000),
);
if estimated_cost > *budget_cents {
tracing::debug!(
provider = pwc.provider.name(),
estimated_cost,
budget = budget_cents,
"Would exceed budget, trying next"
);
continue;
}
}
match pwc
.circuit
.call(pwc.provider.generate(messages, options))
.await
{
Ok(response) => {
tracing::info!(
provider = pwc.provider.name(),
model = pwc.provider.model(),
cost_cents = response.cost_cents,
latency_ms = response.latency_ms,
"Request successful"
);
return Ok(response);
}
Err(CircuitError::Open) => {
tracing::debug!(provider = pwc.provider.name(), "Circuit open");
continue;
}
Err(CircuitError::Timeout) => {
tracing::warn!(provider = pwc.provider.name(), "Request timed out");
last_error = Some(LlmError::Timeout);
continue;
}
Err(CircuitError::Inner(e)) => {
tracing::warn!(provider = pwc.provider.name(), error = %e, "Request failed");
let should_fallback = match &self.strategy {
FallbackStrategy::Sequential => true,
FallbackStrategy::OnRateLimitOrUnavailable => {
matches!(e, LlmError::RateLimit(_) | LlmError::Unavailable(_))
}
FallbackStrategy::OnBudgetExceeded { .. } => true,
};
if should_fallback {
last_error = Some(e);
continue;
} else {
return Err(e);
}
}
}
}
Err(last_error.unwrap_or(LlmError::NoProvidersAvailable))
}
pub fn provider_info(&self) -> Vec<ProviderInfo> {
self.providers
.iter()
.map(|p| ProviderInfo {
name: p.provider.name().to_string(),
model: p.provider.model().to_string(),
is_subscription: p.is_subscription,
circuit_state: p.circuit.state(),
})
.collect()
}
}
#[derive(Debug)]
pub struct ProviderInfo {
pub name: String,
pub model: String,
pub is_subscription: bool,
pub circuit_state: crate::chain::circuit_breaker::CircuitState,
}
fn estimate_tokens(messages: &[Message]) -> u32 {
let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
(total_chars / 4) as u32
}

View File

@ -0,0 +1,167 @@
use crate::cache::{CacheConfig, RequestCache};
use crate::chain::ProviderChain;
use crate::error::LlmError;
#[cfg(feature = "kogral")]
use crate::kogral::KogralIntegration;
use crate::providers::{GenerationOptions, GenerationResponse, Message};
pub struct UnifiedClient {
chain: ProviderChain,
cache: RequestCache,
#[cfg(feature = "kogral")]
kogral: Option<KogralIntegration>,
default_options: GenerationOptions,
}
pub struct UnifiedClientBuilder {
chain: Option<ProviderChain>,
cache_config: CacheConfig,
#[cfg(feature = "kogral")]
kogral: Option<KogralIntegration>,
default_options: GenerationOptions,
}
impl UnifiedClientBuilder {
pub fn new() -> Self {
Self {
chain: None,
cache_config: CacheConfig::default(),
#[cfg(feature = "kogral")]
kogral: None,
default_options: GenerationOptions::default(),
}
}
pub fn auto_detect(mut self) -> Result<Self, LlmError> {
self.chain = Some(ProviderChain::from_detected()?);
Ok(self)
}
pub fn with_chain(mut self, chain: ProviderChain) -> Self {
self.chain = Some(chain);
self
}
pub fn with_cache(mut self, config: CacheConfig) -> Self {
self.cache_config = config;
self
}
pub fn without_cache(mut self) -> Self {
self.cache_config.enabled = false;
self
}
#[cfg(feature = "kogral")]
pub fn with_kogral(mut self) -> Self {
self.kogral = KogralIntegration::new();
self
}
pub fn with_defaults(mut self, options: GenerationOptions) -> Self {
self.default_options = options;
self
}
pub fn build(self) -> Result<UnifiedClient, LlmError> {
let chain = self.chain.ok_or(LlmError::NoProvidersAvailable)?;
Ok(UnifiedClient {
chain,
cache: RequestCache::new(self.cache_config),
#[cfg(feature = "kogral")]
kogral: self.kogral,
default_options: self.default_options,
})
}
}
impl Default for UnifiedClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl UnifiedClient {
pub fn auto() -> Result<Self, LlmError> {
UnifiedClientBuilder::new().auto_detect()?.build()
}
pub fn builder() -> UnifiedClientBuilder {
UnifiedClientBuilder::new()
}
pub async fn generate(
&self,
messages: &[Message],
options: Option<&GenerationOptions>,
) -> Result<GenerationResponse, LlmError> {
let opts = options.unwrap_or(&self.default_options);
self.cache
.get_or_generate(messages, opts, || self.chain.generate(messages, opts))
.await
}
#[cfg(feature = "kogral")]
pub async fn generate_with_kogral(
&self,
messages: &[Message],
options: Option<&GenerationOptions>,
language: Option<&str>,
domain: Option<&str>,
) -> Result<GenerationResponse, LlmError> {
let opts = options.unwrap_or(&self.default_options);
let enriched_messages = if let Some(kogral) = &self.kogral {
let mut ctx = serde_json::json!({});
kogral
.enrich_context(&mut ctx, language, domain)
.await
.map_err(|e| LlmError::Context(e.to_string()))?;
self.inject_kogral_context(messages, &ctx)
} else {
messages.to_vec()
};
self.generate(&enriched_messages, Some(opts)).await
}
#[cfg(feature = "kogral")]
fn inject_kogral_context(
&self,
messages: &[Message],
kogral_ctx: &serde_json::Value,
) -> Vec<Message> {
let mut result = Vec::with_capacity(messages.len());
let mut system_found = false;
for msg in messages {
if matches!(msg.role, crate::providers::Role::System) && !system_found {
let enhanced_content = format!(
"{}\n\n## Project Context (from Kogral)\n{}",
msg.content,
serde_json::to_string_pretty(kogral_ctx).unwrap_or_default()
);
result.push(Message {
role: msg.role,
content: enhanced_content,
});
system_found = true;
} else {
result.push(msg.clone());
}
}
result
}
pub fn providers(&self) -> Vec<crate::chain::ProviderInfo> {
self.chain.provider_info()
}
pub fn clear_cache(&self) {
self.cache.clear();
}
}

View File

@ -0,0 +1,68 @@
#[cfg(feature = "claude-cli")]
use crate::credentials::detector::DetectedCredential;
#[cfg(feature = "claude-cli")]
use crate::providers::CredentialSource;
#[cfg(feature = "claude-cli")]
impl crate::credentials::CredentialDetector {
pub fn detect_claude_cli(&self) -> Option<DetectedCredential> {
let config_dir = dirs::config_dir()?;
let possible_paths = [
config_dir.join("claude").join("credentials.json"),
config_dir.join("claude-cli").join("auth.json"),
config_dir.join("anthropic").join("credentials.json"),
];
for path in &possible_paths {
if let Some(cred) = Self::try_read_claude_credentials(path) {
return Some(cred);
}
}
None
}
fn try_read_claude_credentials(path: &std::path::Path) -> Option<DetectedCredential> {
if !path.exists() {
return None;
}
let content = std::fs::read_to_string(path).ok()?;
let json: serde_json::Value = serde_json::from_str(&content).ok()?;
let token = json.get("access_token")?;
if !token.is_string() {
return None;
}
if Self::is_token_expired(&json) {
tracing::debug!("Claude CLI token expired");
return None;
}
Some(DetectedCredential {
provider: "anthropic".to_string(),
source: CredentialSource::Cli {
path: path.to_path_buf(),
},
is_subscription: true,
})
}
fn is_token_expired(json: &serde_json::Value) -> bool {
let Some(expires) = json.get("expires_at") else {
return false;
};
let Some(exp_str) = expires.as_str() else {
return false;
};
let Ok(exp) = chrono::DateTime::parse_from_rfc3339(exp_str) else {
return false;
};
exp < chrono::Utc::now()
}
}

View File

@ -0,0 +1,130 @@
use crate::providers::CredentialSource;
#[derive(Debug, Clone)]
pub struct DetectedCredential {
pub provider: String,
pub source: CredentialSource,
pub is_subscription: bool,
}
pub struct CredentialDetector {
check_cli: bool,
check_env: bool,
}
impl CredentialDetector {
pub fn new() -> Self {
Self {
check_cli: true,
check_env: true,
}
}
pub fn without_cli(mut self) -> Self {
self.check_cli = false;
self
}
pub fn detect_all(&self) -> Vec<DetectedCredential> {
let mut credentials = Vec::new();
if self.check_cli {
#[cfg(feature = "claude-cli")]
if let Some(cred) = self.detect_claude_cli() {
credentials.push(cred);
}
}
if self.check_env {
if let Some(cred) = self.detect_anthropic_env() {
credentials.push(cred);
}
if let Some(cred) = self.detect_openai_env() {
credentials.push(cred);
}
if let Some(cred) = self.detect_deepseek_env() {
credentials.push(cred);
}
}
if let Some(cred) = self.detect_ollama() {
credentials.push(cred);
}
credentials
}
pub fn detect_for_provider(&self, provider: &str) -> Option<DetectedCredential> {
match provider {
"anthropic" | "claude" => {
#[cfg(feature = "claude-cli")]
if self.check_cli {
if let Some(cred) = self.detect_claude_cli() {
return Some(cred);
}
}
self.detect_anthropic_env()
}
"openai" => self.detect_openai_env(),
"deepseek" => self.detect_deepseek_env(),
"ollama" => self.detect_ollama(),
_ => None,
}
}
pub fn detect_anthropic_env(&self) -> Option<DetectedCredential> {
if std::env::var("ANTHROPIC_API_KEY").is_ok() {
Some(DetectedCredential {
provider: "anthropic".to_string(),
source: CredentialSource::EnvVar {
name: "ANTHROPIC_API_KEY".to_string(),
},
is_subscription: false,
})
} else {
None
}
}
pub fn detect_openai_env(&self) -> Option<DetectedCredential> {
if std::env::var("OPENAI_API_KEY").is_ok() {
Some(DetectedCredential {
provider: "openai".to_string(),
source: CredentialSource::EnvVar {
name: "OPENAI_API_KEY".to_string(),
},
is_subscription: false,
})
} else {
None
}
}
pub fn detect_deepseek_env(&self) -> Option<DetectedCredential> {
if std::env::var("DEEPSEEK_API_KEY").is_ok() {
Some(DetectedCredential {
provider: "deepseek".to_string(),
source: CredentialSource::EnvVar {
name: "DEEPSEEK_API_KEY".to_string(),
},
is_subscription: false,
})
} else {
None
}
}
pub fn detect_ollama(&self) -> Option<DetectedCredential> {
Some(DetectedCredential {
provider: "ollama".to_string(),
source: CredentialSource::None,
is_subscription: false,
})
}
}
impl Default for CredentialDetector {
fn default() -> Self {
Self::new()
}
}

View File

@ -0,0 +1,6 @@
pub mod detector;
#[cfg(feature = "claude-cli")]
pub mod claude_cli;
pub use detector::{CredentialDetector, DetectedCredential};

View File

@ -0,0 +1,47 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum LlmError {
#[error("No providers available")]
NoProvidersAvailable,
#[error("Missing credential: {0}")]
MissingCredential(String),
#[error("Network error: {0}")]
Network(String),
#[error("API error: {0}")]
Api(String),
#[error("Rate limited: {0}")]
RateLimit(String),
#[error("Provider unavailable: {0}")]
Unavailable(String),
#[error("Request timeout")]
Timeout,
#[error("Parse error: {0}")]
Parse(String),
#[error("Context error: {0}")]
Context(String),
#[error("Circuit breaker open for provider")]
CircuitOpen,
}
impl LlmError {
pub fn is_rate_limit(&self) -> bool {
matches!(self, Self::RateLimit(_))
}
pub fn is_retriable(&self) -> bool {
matches!(
self,
Self::Network(_) | Self::RateLimit(_) | Self::Timeout | Self::Unavailable(_)
)
}
}

View File

@ -0,0 +1,216 @@
#[cfg(feature = "kogral")]
use std::path::PathBuf;
#[cfg(feature = "kogral")]
pub struct KogralIntegration {
kogral_path: PathBuf,
}
#[cfg(feature = "kogral")]
impl KogralIntegration {
pub fn new() -> Option<Self> {
let possible_paths = [
dirs::home_dir()?.join(".kogral"),
PathBuf::from("/Users/Akasha/Development/kogral/.kogral"),
];
for path in &possible_paths {
if path.exists() {
return Some(Self {
kogral_path: path.clone(),
});
}
}
None
}
pub fn with_path(path: impl Into<PathBuf>) -> Self {
Self {
kogral_path: path.into(),
}
}
pub async fn get_guidelines(&self, language: &str) -> Result<Vec<Guideline>, KogralError> {
let guidelines_dir = self.kogral_path.join("default");
let mut guidelines = Vec::new();
if let Ok(entries) = std::fs::read_dir(&guidelines_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_none_or(|e| e != "md") {
continue;
}
let Ok(content) = std::fs::read_to_string(&path) else {
continue;
};
if let Some(guideline) = Self::parse_guideline(&content, language) {
guidelines.push(guideline);
}
}
}
Ok(guidelines)
}
pub async fn get_patterns(&self, domain: &str) -> Result<Vec<Pattern>, KogralError> {
let patterns_dir = self.kogral_path.join("default");
let mut patterns = Vec::new();
if let Ok(entries) = std::fs::read_dir(&patterns_dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_none_or(|e| e != "md") {
continue;
}
let Ok(content) = std::fs::read_to_string(&path) else {
continue;
};
if let Some(pattern) = Self::parse_pattern(&content, domain) {
patterns.push(pattern);
}
}
}
Ok(patterns)
}
pub async fn enrich_context(
&self,
context: &mut serde_json::Value,
language: Option<&str>,
domain: Option<&str>,
) -> Result<(), KogralError> {
let mut kogral_context = serde_json::json!({});
if let Some(lang) = language {
let guidelines = self.get_guidelines(lang).await?;
if !guidelines.is_empty() {
kogral_context["guidelines"] = serde_json::to_value(&guidelines)?;
}
}
if let Some(dom) = domain {
let patterns = self.get_patterns(dom).await?;
if !patterns.is_empty() {
kogral_context["patterns"] = serde_json::to_value(&patterns)?;
}
}
if let Some(obj) = context.as_object_mut() {
obj.insert("kogral".to_string(), kogral_context);
}
Ok(())
}
fn parse_guideline(content: &str, language: &str) -> Option<Guideline> {
let (frontmatter, body) = Self::split_frontmatter(content)?;
let meta: serde_yaml::Value = serde_yaml::from_str(&frontmatter).ok()?;
let node_type = meta.get("node_type")?.as_str()?;
if node_type != "guideline" {
return None;
}
let tags: Vec<String> = meta
.get("tags")?
.as_sequence()?
.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect();
if !tags.iter().any(|t| t.eq_ignore_ascii_case(language)) {
return None;
}
Some(Guideline {
title: meta.get("title")?.as_str()?.to_string(),
content: body.to_string(),
tags,
})
}
fn parse_pattern(content: &str, domain: &str) -> Option<Pattern> {
let (frontmatter, body) = Self::split_frontmatter(content)?;
let meta: serde_yaml::Value = serde_yaml::from_str(&frontmatter).ok()?;
let node_type = meta.get("node_type")?.as_str()?;
if node_type != "pattern" {
return None;
}
let tags: Vec<String> = meta
.get("tags")?
.as_sequence()?
.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect();
if !tags.iter().any(|t| t.eq_ignore_ascii_case(domain)) {
return None;
}
Some(Pattern {
title: meta.get("title")?.as_str()?.to_string(),
content: body.to_string(),
tags,
})
}
fn split_frontmatter(content: &str) -> Option<(String, String)> {
let content = content.trim();
if !content.starts_with("---") {
return None;
}
let after_first = &content[3..];
let end = after_first.find("---")?;
let frontmatter = after_first[..end].trim().to_string();
let body = after_first[end + 3..].trim().to_string();
Some((frontmatter, body))
}
}
#[cfg(feature = "kogral")]
impl Default for KogralIntegration {
fn default() -> Self {
Self::new().unwrap_or_else(|| Self {
kogral_path: PathBuf::from(".kogral"),
})
}
}
#[cfg(feature = "kogral")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct Guideline {
pub title: String,
pub content: String,
pub tags: Vec<String>,
}
#[cfg(feature = "kogral")]
#[derive(Debug, Clone, serde::Serialize)]
pub struct Pattern {
pub title: String,
pub content: String,
pub tags: Vec<String>,
}
#[cfg(feature = "kogral")]
#[derive(Debug, thiserror::Error)]
pub enum KogralError {
#[error("Kogral not found")]
NotFound,
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Parse error: {0}")]
Parse(String),
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}

View File

@ -0,0 +1,5 @@
#[cfg(feature = "kogral")]
pub mod integration;
#[cfg(feature = "kogral")]
pub use integration::{Guideline, KogralError, KogralIntegration, Pattern};

View File

@ -0,0 +1,64 @@
//! Unified LLM abstraction with CLI detection, fallback, and caching
//!
//! # Features
//!
//! - **Credential auto-detection**: Finds CLI credentials (Claude, OpenAI) and
//! API keys
//! - **Provider fallback**: Automatic failover with circuit breaker pattern
//! - **Smart caching**: xxHash-based deduplication reduces duplicate API calls
//! - **Kogral integration**: Inject project context from knowledge base
//! - **Cost tracking**: Transparent cost estimation across providers
//!
//! # Quick Start
//!
//! ```no_run
//! use stratum_llm::{UnifiedClient, Message, Role};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let client = UnifiedClient::auto()?;
//!
//! let messages = vec![
//! Message {
//! role: Role::User,
//! content: "What is Rust?".to_string(),
//! }
//! ];
//!
//! let response = client.generate(&messages, None).await?;
//! println!("{}", response.content);
//!
//! Ok(())
//! }
//! ```
pub mod cache;
pub mod chain;
pub mod client;
pub mod credentials;
pub mod error;
pub mod kogral;
pub mod metrics;
pub mod providers;
pub use cache::{CacheConfig, RequestCache};
pub use chain::{CircuitBreakerConfig, FallbackStrategy, ProviderChain};
pub use client::{UnifiedClient, UnifiedClientBuilder};
pub use credentials::{CredentialDetector, DetectedCredential};
pub use error::LlmError;
#[cfg(feature = "kogral")]
pub use kogral::{Guideline, KogralIntegration, Pattern};
#[cfg(feature = "metrics")]
pub use metrics::LlmMetrics;
#[cfg(feature = "anthropic")]
pub use providers::AnthropicProvider;
#[cfg(feature = "deepseek")]
pub use providers::DeepSeekProvider;
#[cfg(feature = "ollama")]
pub use providers::OllamaProvider;
#[cfg(feature = "openai")]
pub use providers::OpenAiProvider;
pub use providers::{
ConfiguredProvider, CredentialSource, GenerationOptions, GenerationResponse, LlmProvider,
Message, Role,
};

View File

@ -0,0 +1,74 @@
#[cfg(feature = "metrics")]
use prometheus::{Counter, Histogram, IntGauge, Registry};
#[cfg(feature = "metrics")]
pub struct LlmMetrics {
pub requests_total: Counter,
pub requests_success: Counter,
pub requests_failed: Counter,
pub cache_hits: Counter,
pub cache_misses: Counter,
pub circuit_opens: Counter,
pub fallbacks: Counter,
pub latency_seconds: Histogram,
pub cost_cents: Counter,
pub active_circuits_open: IntGauge,
}
#[cfg(feature = "metrics")]
impl LlmMetrics {
pub fn new() -> Self {
Self {
requests_total: Counter::new("stratum_llm_requests_total", "Total LLM requests")
.unwrap(),
requests_success: Counter::new(
"stratum_llm_requests_success_total",
"Successful LLM requests",
)
.unwrap(),
requests_failed: Counter::new(
"stratum_llm_requests_failed_total",
"Failed LLM requests",
)
.unwrap(),
cache_hits: Counter::new("stratum_llm_cache_hits_total", "Cache hits").unwrap(),
cache_misses: Counter::new("stratum_llm_cache_misses_total", "Cache misses").unwrap(),
circuit_opens: Counter::new("stratum_llm_circuit_opens_total", "Circuit breaker opens")
.unwrap(),
fallbacks: Counter::new("stratum_llm_fallbacks_total", "Provider fallbacks").unwrap(),
latency_seconds: Histogram::with_opts(
prometheus::HistogramOpts::new("stratum_llm_latency_seconds", "Request latency")
.buckets(vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0]),
)
.unwrap(),
cost_cents: Counter::new("stratum_llm_cost_cents_total", "Total cost in cents")
.unwrap(),
active_circuits_open: IntGauge::new(
"stratum_llm_circuits_open",
"Currently open circuit breakers",
)
.unwrap(),
}
}
pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
registry.register(Box::new(self.requests_total.clone()))?;
registry.register(Box::new(self.requests_success.clone()))?;
registry.register(Box::new(self.requests_failed.clone()))?;
registry.register(Box::new(self.cache_hits.clone()))?;
registry.register(Box::new(self.cache_misses.clone()))?;
registry.register(Box::new(self.circuit_opens.clone()))?;
registry.register(Box::new(self.fallbacks.clone()))?;
registry.register(Box::new(self.latency_seconds.clone()))?;
registry.register(Box::new(self.cost_cents.clone()))?;
registry.register(Box::new(self.active_circuits_open.clone()))?;
Ok(())
}
}
#[cfg(feature = "metrics")]
impl Default for LlmMetrics {
fn default() -> Self {
Self::new()
}
}

View File

@ -0,0 +1,181 @@
use async_trait::async_trait;
use crate::error::LlmError;
use crate::providers::{
GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};
#[cfg(feature = "anthropic")]
pub struct AnthropicProvider {
client: reqwest::Client,
api_key: String,
model: String,
}
#[cfg(feature = "anthropic")]
impl AnthropicProvider {
const BASE_URL: &'static str = "https://api.anthropic.com/v1";
const API_VERSION: &'static str = "2023-06-01";
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| LlmError::MissingCredential("ANTHROPIC_API_KEY".to_string()))?;
Ok(Self::new(api_key, model))
}
pub fn sonnet() -> Result<Self, LlmError> {
Self::from_env("claude-sonnet-4-5-20250929")
}
pub fn opus() -> Result<Self, LlmError> {
Self::from_env("claude-opus-4-5-20251101")
}
pub fn haiku() -> Result<Self, LlmError> {
Self::from_env("claude-haiku-4-5-20251001")
}
}
#[cfg(feature = "anthropic")]
#[async_trait]
impl LlmProvider for AnthropicProvider {
fn name(&self) -> &str {
"anthropic"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
true
}
async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, LlmError> {
let start = std::time::Instant::now();
let (system, user_messages): (Vec<_>, Vec<_>) = messages
.iter()
.partition(|m| matches!(m.role, crate::providers::Role::System));
let system_content = system.first().map(|m| m.content.as_str());
let mut body = serde_json::json!({
"model": self.model,
"messages": user_messages.iter().map(|m| {
serde_json::json!({
"role": match m.role {
crate::providers::Role::User => "user",
crate::providers::Role::Assistant => "assistant",
crate::providers::Role::System => "user",
},
"content": m.content,
})
}).collect::<Vec<_>>(),
"max_tokens": options.max_tokens.unwrap_or(4096),
});
if let Some(sys) = system_content {
body["system"] = serde_json::json!(sys);
}
if let Some(temp) = options.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(top_p) = options.top_p {
body["top_p"] = serde_json::json!(top_p);
}
if !options.stop_sequences.is_empty() {
body["stop_sequences"] = serde_json::json!(options.stop_sequences);
}
let response = self
.client
.post(format!("{}/messages", Self::BASE_URL))
.header("x-api-key", &self.api_key)
.header("anthropic-version", Self::API_VERSION)
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| LlmError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(LlmError::RateLimit(text));
}
return Err(LlmError::Api(format!("{}: {}", status, text)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| LlmError::Parse(e.to_string()))?;
let content = json["content"][0]["text"]
.as_str()
.unwrap_or("")
.to_string();
let input_tokens = json["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32;
let output_tokens = json["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32;
Ok(GenerationResponse {
content,
model: self.model.clone(),
provider: "anthropic".to_string(),
input_tokens,
output_tokens,
cost_cents: self.estimate_cost(input_tokens, output_tokens),
latency_ms: start.elapsed().as_millis() as u64,
})
}
async fn stream(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<StreamResponse, LlmError> {
Err(LlmError::Unavailable(
"Streaming not yet implemented".to_string(),
))
}
fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
(input_cost + output_cost) * 100.0
}
fn cost_per_1m_input(&self) -> f64 {
match self.model.as_str() {
m if m.contains("opus") => 15.0,
m if m.contains("sonnet") => 3.0,
m if m.contains("haiku") => 1.0,
_ => 3.0,
}
}
fn cost_per_1m_output(&self) -> f64 {
match self.model.as_str() {
m if m.contains("opus") => 75.0,
m if m.contains("sonnet") => 15.0,
m if m.contains("haiku") => 5.0,
_ => 15.0,
}
}
}

View File

@ -0,0 +1,147 @@
use async_trait::async_trait;
use crate::error::LlmError;
use crate::providers::{
GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};
#[cfg(feature = "deepseek")]
pub struct DeepSeekProvider {
client: reqwest::Client,
api_key: String,
model: String,
}
#[cfg(feature = "deepseek")]
impl DeepSeekProvider {
const BASE_URL: &'static str = "https://api.deepseek.com/v1";
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
let api_key = std::env::var("DEEPSEEK_API_KEY")
.map_err(|_| LlmError::MissingCredential("DEEPSEEK_API_KEY".to_string()))?;
Ok(Self::new(api_key, model))
}
pub fn coder() -> Result<Self, LlmError> {
Self::from_env("deepseek-coder")
}
pub fn chat() -> Result<Self, LlmError> {
Self::from_env("deepseek-chat")
}
}
#[cfg(feature = "deepseek")]
#[async_trait]
impl LlmProvider for DeepSeekProvider {
fn name(&self) -> &str {
"deepseek"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
true
}
async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, LlmError> {
let start = std::time::Instant::now();
let body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": match m.role {
crate::providers::Role::System => "system",
crate::providers::Role::User => "user",
crate::providers::Role::Assistant => "assistant",
},
"content": m.content,
})
}).collect::<Vec<_>>(),
"max_tokens": options.max_tokens.unwrap_or(4096),
"temperature": options.temperature.unwrap_or(0.7),
});
let response = self
.client
.post(format!("{}/chat/completions", Self::BASE_URL))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| LlmError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(LlmError::RateLimit(text));
}
return Err(LlmError::Api(format!("{}: {}", status, text)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| LlmError::Parse(e.to_string()))?;
let content = json["choices"][0]["message"]["content"]
.as_str()
.unwrap_or("")
.to_string();
let input_tokens = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let output_tokens = json["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32;
Ok(GenerationResponse {
content,
model: self.model.clone(),
provider: "deepseek".to_string(),
input_tokens,
output_tokens,
cost_cents: self.estimate_cost(input_tokens, output_tokens),
latency_ms: start.elapsed().as_millis() as u64,
})
}
async fn stream(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<StreamResponse, LlmError> {
Err(LlmError::Unavailable(
"Streaming not yet implemented".to_string(),
))
}
fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
(input_cost + output_cost) * 100.0
}
fn cost_per_1m_input(&self) -> f64 {
0.14
}
fn cost_per_1m_output(&self) -> f64 {
0.28
}
}

View File

@ -0,0 +1,23 @@
pub mod traits;
#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "deepseek")]
pub mod deepseek;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "openai")]
pub mod openai;
#[cfg(feature = "anthropic")]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "deepseek")]
pub use deepseek::DeepSeekProvider;
#[cfg(feature = "ollama")]
pub use ollama::OllamaProvider;
#[cfg(feature = "openai")]
pub use openai::OpenAiProvider;
pub use traits::{
ConfiguredProvider, CredentialSource, GenerationOptions, GenerationResponse, LlmProvider,
Message, Role, StreamChunk, StreamResponse,
};

View File

@ -0,0 +1,159 @@
use async_trait::async_trait;
use crate::error::LlmError;
use crate::providers::{
GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
client: reqwest::Client,
base_url: String,
model: String,
}
#[cfg(feature = "ollama")]
impl OllamaProvider {
pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
base_url: base_url.into(),
model: model.into(),
}
}
pub fn from_env(model: impl Into<String>) -> Self {
let base_url =
std::env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string());
Self::new(base_url, model)
}
pub fn llama3() -> Self {
Self::from_env("llama3")
}
pub fn codellama() -> Self {
Self::from_env("codellama")
}
pub fn mistral() -> Self {
Self::from_env("mistral")
}
}
#[cfg(feature = "ollama")]
impl Default for OllamaProvider {
fn default() -> Self {
Self::llama3()
}
}
#[cfg(feature = "ollama")]
#[async_trait]
impl LlmProvider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
self.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.is_ok()
}
async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, LlmError> {
let start = std::time::Instant::now();
let prompt = messages
.iter()
.map(|m| match m.role {
crate::providers::Role::System => format!("System: {}", m.content),
crate::providers::Role::User => format!("User: {}", m.content),
crate::providers::Role::Assistant => format!("Assistant: {}", m.content),
})
.collect::<Vec<_>>()
.join("\n\n");
let mut body = serde_json::json!({
"model": self.model,
"prompt": prompt,
"stream": false,
});
if let Some(temp) = options.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(top_p) = options.top_p {
body["top_p"] = serde_json::json!(top_p);
}
if !options.stop_sequences.is_empty() {
body["stop"] = serde_json::json!(options.stop_sequences);
}
let response = self
.client
.post(format!("{}/api/generate", self.base_url))
.json(&body)
.send()
.await
.map_err(|e| LlmError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(LlmError::Api(format!("{}: {}", status, text)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| LlmError::Parse(e.to_string()))?;
let content = json["response"].as_str().unwrap_or("").to_string();
let input_tokens = prompt.len() as u32 / 4;
let output_tokens = content.len() as u32 / 4;
Ok(GenerationResponse {
content,
model: self.model.clone(),
provider: "ollama".to_string(),
input_tokens,
output_tokens,
cost_cents: 0.0,
latency_ms: start.elapsed().as_millis() as u64,
})
}
async fn stream(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<StreamResponse, LlmError> {
Err(LlmError::Unavailable(
"Streaming not yet implemented".to_string(),
))
}
fn estimate_cost(&self, _input_tokens: u32, _output_tokens: u32) -> f64 {
0.0
}
fn cost_per_1m_input(&self) -> f64 {
0.0
}
fn cost_per_1m_output(&self) -> f64 {
0.0
}
}

View File

@ -0,0 +1,161 @@
use async_trait::async_trait;
use crate::error::LlmError;
use crate::providers::{
GenerationOptions, GenerationResponse, LlmProvider, Message, StreamResponse,
};
#[cfg(feature = "openai")]
pub struct OpenAiProvider {
client: reqwest::Client,
api_key: String,
model: String,
}
#[cfg(feature = "openai")]
impl OpenAiProvider {
const BASE_URL: &'static str = "https://api.openai.com/v1";
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn from_env(model: impl Into<String>) -> Result<Self, LlmError> {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| LlmError::MissingCredential("OPENAI_API_KEY".to_string()))?;
Ok(Self::new(api_key, model))
}
pub fn gpt4o() -> Result<Self, LlmError> {
Self::from_env("gpt-4o")
}
pub fn gpt4_turbo() -> Result<Self, LlmError> {
Self::from_env("gpt-4-turbo")
}
pub fn gpt35_turbo() -> Result<Self, LlmError> {
Self::from_env("gpt-3.5-turbo")
}
}
#[cfg(feature = "openai")]
#[async_trait]
impl LlmProvider for OpenAiProvider {
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
true
}
async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, LlmError> {
let start = std::time::Instant::now();
let body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": match m.role {
crate::providers::Role::System => "system",
crate::providers::Role::User => "user",
crate::providers::Role::Assistant => "assistant",
},
"content": m.content,
})
}).collect::<Vec<_>>(),
"max_tokens": options.max_tokens.unwrap_or(4096),
"temperature": options.temperature.unwrap_or(0.7),
});
let response = self
.client
.post(format!("{}/chat/completions", Self::BASE_URL))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| LlmError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
if status.as_u16() == 429 {
return Err(LlmError::RateLimit(text));
}
return Err(LlmError::Api(format!("{}: {}", status, text)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| LlmError::Parse(e.to_string()))?;
let content = json["choices"][0]["message"]["content"]
.as_str()
.unwrap_or("")
.to_string();
let input_tokens = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let output_tokens = json["usage"]["completion_tokens"].as_u64().unwrap_or(0) as u32;
Ok(GenerationResponse {
content,
model: self.model.clone(),
provider: "openai".to_string(),
input_tokens,
output_tokens,
cost_cents: self.estimate_cost(input_tokens, output_tokens),
latency_ms: start.elapsed().as_millis() as u64,
})
}
async fn stream(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<StreamResponse, LlmError> {
Err(LlmError::Unavailable(
"Streaming not yet implemented".to_string(),
))
}
fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
let input_cost = (input_tokens as f64 / 1_000_000.0) * self.cost_per_1m_input();
let output_cost = (output_tokens as f64 / 1_000_000.0) * self.cost_per_1m_output();
(input_cost + output_cost) * 100.0
}
fn cost_per_1m_input(&self) -> f64 {
match self.model.as_str() {
"gpt-4o" => 5.0,
"gpt-4-turbo" => 10.0,
"gpt-3.5-turbo" => 0.5,
_ => 5.0,
}
}
fn cost_per_1m_output(&self) -> f64 {
match self.model.as_str() {
"gpt-4o" => 15.0,
"gpt-4-turbo" => 30.0,
"gpt-3.5-turbo" => 1.5,
_ => 15.0,
}
}
}

View File

@ -0,0 +1,95 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: String,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
}
#[derive(Debug, Clone, Default)]
pub struct GenerationOptions {
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub top_p: Option<f32>,
pub stop_sequences: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct GenerationResponse {
pub content: String,
pub model: String,
pub provider: String,
pub input_tokens: u32,
pub output_tokens: u32,
pub cost_cents: f64,
pub latency_ms: u64,
}
pub type StreamChunk = String;
pub type StreamResponse = std::pin::Pin<
Box<dyn futures::Stream<Item = Result<StreamChunk, crate::error::LlmError>> + Send>,
>;
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Provider name (e.g., "anthropic", "openai", "ollama")
fn name(&self) -> &str;
/// Model identifier
fn model(&self) -> &str;
/// Check if provider is available and configured
async fn is_available(&self) -> bool;
/// Generate a completion
async fn generate(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<GenerationResponse, crate::error::LlmError>;
/// Stream a completion (future work)
async fn stream(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<StreamResponse, crate::error::LlmError>;
/// Estimate cost for a request (before sending)
fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64;
/// Cost per 1M input tokens in cents
fn cost_per_1m_input(&self) -> f64;
/// Cost per 1M output tokens in cents
fn cost_per_1m_output(&self) -> f64;
}
/// Credential source for a provider
#[derive(Debug, Clone)]
pub enum CredentialSource {
/// CLI tool credentials (subscription-based, no per-token cost)
Cli { path: std::path::PathBuf },
/// API key from environment variable
EnvVar { name: String },
/// API key from config file
ConfigFile { path: std::path::PathBuf },
/// No credentials needed (local provider)
None,
}
/// Provider with credential metadata
pub struct ConfiguredProvider {
pub provider: Box<dyn LlmProvider>,
pub credential_source: CredentialSource,
pub priority: u32,
}

View File

@ -32,6 +32,15 @@ Infrastructure automation and deployment tools.
See [Operations Portfolio Docs](en/ops/) for technical details. See [Operations Portfolio Docs](en/ops/) for technical details.
### Architecture
Cross-cutting architectural decisions documented as ADRs.
- [ADR-001: Stratum-Embeddings](en/architecture/adrs/001-stratum-embeddings.md) - Unified embedding library
- [ADR-002: Stratum-LLM](en/architecture/adrs/002-stratum-llm.md) - Unified LLM provider library
See [Architecture Docs](en/architecture/) for all ADRs.
## Quick Start ## Quick Start
1. Choose your language: [English](en/) | [Español](es/) 1. Choose your language: [English](en/) | [Español](es/)
@ -47,3 +56,4 @@ Each language directory contains:
- `stratiumiops-technical-specs.md` - Technical specifications - `stratiumiops-technical-specs.md` - Technical specifications
- `ia/` - AI portfolio documentation - `ia/` - AI portfolio documentation
- `ops/` - Operations portfolio documentation - `ops/` - Operations portfolio documentation
- `architecture/` - Architecture documentation and ADRs

View File

@ -34,6 +34,16 @@ Infrastructure automation and deployment tools.
See [ops/](ops/) directory for full operations portfolio documentation. See [ops/](ops/) directory for full operations portfolio documentation.
### Architecture
Architectural decisions and ecosystem design.
- [**ADRs**](architecture/adrs/) - Architecture Decision Records
- [ADR-001: Stratum-Embeddings](architecture/adrs/001-stratum-embeddings.md) - Unified embedding library
- [ADR-002: Stratum-LLM](architecture/adrs/002-stratum-llm.md) - Unified LLM provider library
See [architecture/](architecture/) directory for full architecture documentation.
## Navigation ## Navigation
- [Back to root documentation](../) - [Back to root documentation](../)

View File

@ -0,0 +1,30 @@
# Architecture
Architecture documentation for the STRATUMIOPS ecosystem.
## Contents
### ADRs (Architecture Decision Records)
Documented architectural decisions following the ADR format:
- [**ADR-001: Stratum-Embeddings**](adrs/001-stratum-embeddings.md) - Unified embedding library
- [**ADR-002: Stratum-LLM**](adrs/002-stratum-llm.md) - Unified LLM provider library
## ADR Format
Each ADR follows this structure:
| Section | Description |
| --------------- | ------------------------------------------ |
| Status | Proposed, Accepted, Deprecated, Superseded |
| Context | Problem and current state |
| Decision | Chosen solution |
| Rationale | Why this solution |
| Consequences | Positive, negative, mitigations |
| Success Metrics | How to measure the outcome |
## Navigation
- [Back to main documentation](../)
- [Spanish version](../../es/architecture/)

View File

@ -0,0 +1,279 @@
# ADR-001: Stratum-Embeddings - Unified Embedding Library
## Status
**Proposed**
## Context
### Current State: Fragmented Implementations
The ecosystem has 3 independent embedding implementations:
| Project | Location | Providers | Caching |
| ------------ | ------------------------------------- | ----------------------------- | ------- |
| Kogral | `kogral-core/src/embeddings/` | fastembed, rig-core (partial) | No |
| Provisioning | `provisioning-rag/src/embeddings.rs` | OpenAI direct | No |
| Vapora | `vapora-llm-router/src/embeddings.rs` | OpenAI, HuggingFace, Ollama | No |
### Identified Problems
#### 1. Duplicated Code
Each project reimplements:
- HTTP client for OpenAI embeddings
- JSON response parsing
- Error handling
- Token estimation
**Impact**: ~400 duplicated lines, inconsistent error handling.
#### 2. No Caching
Embeddings regenerated every time:
```text
"What is Rust?" → OpenAI → 1536 dims → $0.00002
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (same result)
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (same result)
```
**Impact**: Unnecessary costs, additional latency, more frequent rate limits.
#### 3. No Fallback
If OpenAI fails, everything fails. No fallback to local alternatives (fastembed, Ollama).
**Impact**: Reduced availability, total dependency on one provider.
#### 4. Silent Dimension Mismatch
Different providers produce different dimensions:
| Provider | Model | Dimensions |
| --------- | ---------------------- | ---------- |
| fastembed | bge-small-en | 384 |
| fastembed | bge-large-en | 1024 |
| OpenAI | text-embedding-3-small | 1536 |
| OpenAI | text-embedding-3-large | 3072 |
| Ollama | nomic-embed-text | 768 |
**Impact**: Corrupt vector indices if provider changes.
#### 5. No Metrics
No visibility into usage, cache hit rate, latency per provider, or accumulated costs.
## Decision
Create `stratum-embeddings` as a unified crate that:
1. **Unifies** implementations from Kogral, Provisioning, and Vapora
2. **Adds caching** to avoid recomputing identical embeddings
3. **Implements fallback** between providers (cloud → local)
4. **Clearly documents** dimensions and limitations per provider
5. **Exposes metrics** for observability
6. **Provides VectorStore trait** with LanceDB and SurrealDB backends based on project needs
### Storage Backend Decision
Each project chooses its vector storage backend based on priority:
| Project | Backend | Priority | Justification |
| ------------ | --------- | -------------- | -------------------------------------------------- |
| Kogral | SurrealDB | Graph richness | Knowledge Graph needs unified graph+vector queries |
| Provisioning | LanceDB | Vector scale | RAG with millions of document chunks |
| Vapora | LanceDB | Vector scale | Execution traces, pattern matching at scale |
#### Why SurrealDB for Kogral
Kogral is a Knowledge Graph where relationships are the primary value.
With hybrid architecture (LanceDB vectors + SurrealDB graph), a typical query would require:
1. LanceDB: vector search → candidate_ids
2. SurrealDB: graph filter on candidates → results
3. App layer: merge, re-rank, deduplication
**Accepted trade-off**: SurrealDB has worse pure vector performance than LanceDB,
but Kogral's scale is limited by human curation of knowledge (typically 10K-100K concepts).
#### Why LanceDB for Provisioning and Vapora
| Aspect | SurrealDB | LanceDB |
| --------------- | ---------- | -------------------- |
| Storage format | Row-based | Columnar (Lance) |
| Vector index | HNSW (RAM) | IVF-PQ (disk-native) |
| Practical scale | Millions | Billions |
| Compression | ~1x | ~32x (PQ) |
| Zero-copy read | No | Yes |
### Architecture
```text
┌─────────────────────────────────────────────────────────────────┐
│ stratum-embeddings │
├─────────────────────────────────────────────────────────────────┤
│ EmbeddingProvider trait │
│ ├─ embed(text) → Vec<f32>
│ ├─ embed_batch(texts) → Vec<Vec<f32>> │
│ ├─ dimensions() → usize │
│ └─ is_local() → bool │
│ │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │
│ │ FastEmbed │ │ OpenAI │ │ Ollama │ │
│ │ (local) │ │ (cloud) │ │ (local) │ │
│ └───────────┘ └───────────┘ └───────────┘ │
│ └────────────┬────────────┘ │
│ ▼ │
│ EmbeddingCache (memory/disk) │
│ │ │
│ ▼ │
│ EmbeddingService │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ VectorStore trait │ │
│ │ ├─ upsert(id, embedding, metadata) │ │
│ │ ├─ search(embedding, limit, filter) → Vec<Match> │ │
│ │ └─ delete(id) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ SurrealDbStore │ │ LanceDbStore │ │
│ │ (Kogral) │ │ (Prov/Vapora) │ │
│ └─────────────────┘ └─────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
## Rationale
### Why Caching is Critical
For a typical RAG system (10,000 chunks):
- **Without cache**: Re-indexing and repeated queries multiply costs
- **With cache**: First indexing pays, rest are cache hits
**Estimated savings**: 60-80% in embedding costs.
### Why Fallback is Important
| Scenario | Without Fallback | With Fallback |
| ----------------- | ---------------- | -------------------- |
| OpenAI rate limit | ERROR | → fastembed (local) |
| OpenAI downtime | ERROR | → Ollama (local) |
| No internet | ERROR | → fastembed (local) |
### Why Local Providers First
For development: fastembed loads local model (~100MB), no API keys required, no costs, works offline.
For production: OpenAI for quality, fastembed as fallback.
## Consequences
### Positive
1. Single source of truth for the entire ecosystem
2. 60-80% fewer embedding API calls (caching)
3. High availability with local providers (fallback)
4. Usage and cost metrics
5. Feature-gated: only compile what you need
6. Storage flexibility: VectorStore trait allows choosing backend per project
### Negative
1. **Dimension lock-in**: Changing provider requires re-indexing
2. **Cache invalidation**: Updated content may serve stale embeddings
3. **Model download**: fastembed downloads ~100MB on first use
4. **Storage lock-in per project**: Kogral tied to SurrealDB, others to LanceDB
### Mitigations
| Negative | Mitigation |
| ----------------- | ---------------------------------------------- |
| Dimension lock-in | Document clearly, warn on provider change |
| Stale cache | Configurable TTL, bypass option |
| Model download | Show progress, cache in ~/.cache/fastembed |
| Storage lock-in | Conscious decision based on project priorities |
## Success Metrics
| Metric | Current | Target |
| ------------------------- | ------- | ------ |
| Duplicate implementations | 3 | 1 |
| Cache hit rate | 0% | >60% |
| Fallback availability | 0% | 100% |
| Cost per 10K embeddings | ~$0.20 | ~$0.05 |
## Provider Selection Guide
### Development
```rust
// Local, free, offline
let service = EmbeddingService::builder()
.with_provider(FastEmbedProvider::small()?) // 384 dims
.with_memory_cache()
.build()?;
```
### Production (Quality)
```rust
// OpenAI with local fallback
let service = EmbeddingService::builder()
.with_provider(OpenAiEmbeddingProvider::large()?) // 3072 dims
.with_provider(FastEmbedProvider::large()?) // Fallback
.with_memory_cache()
.build()?;
```
### Production (Cost-Optimized)
```rust
// OpenAI small with fallback
let service = EmbeddingService::builder()
.with_provider(OpenAiEmbeddingProvider::small()?) // 1536 dims
.with_provider(OllamaEmbeddingProvider::nomic()) // Fallback
.with_memory_cache()
.build()?;
```
## Dimension Compatibility Matrix
| If using... | Can switch to... | CANNOT switch to... |
| ---------------------- | --------------------------- | ------------------- |
| fastembed small (384) | fastembed small, all-minilm | Any other |
| fastembed large (1024) | fastembed large | Any other |
| OpenAI small (1536) | OpenAI small, ada-002 | Any other |
| OpenAI large (3072) | OpenAI large | Any other |
**Rule**: Only switch between models with the SAME dimensions.
## Implementation Priority
| Order | Feature | Reason |
| ----- | ----------------------- | -------------------------- |
| 1 | EmbeddingProvider trait | Foundation for everything |
| 2 | FastEmbed provider | Works without API keys |
| 3 | Memory cache | Biggest cost impact |
| 4 | VectorStore trait | Storage abstraction |
| 5 | SurrealDbStore | Kogral needs graph+vector |
| 6 | LanceDbStore | Provisioning/Vapora scale |
| 7 | OpenAI provider | Production |
| 8 | Ollama provider | Local fallback |
| 9 | Batch processing | Efficiency |
| 10 | Metrics | Observability |
## References
**Existing Implementations**:
- Kogral: `kogral-core/src/embeddings/`
- Vapora: `vapora-llm-router/src/embeddings.rs`
- Provisioning: `provisioning/platform/crates/rag/src/embeddings.rs`
**Target Location**: `stratumiops/crates/stratum-embeddings/`

View File

@ -0,0 +1,279 @@
# ADR-002: Stratum-LLM - Unified LLM Provider Library
## Status
**Proposed**
## Context
### Current State: Fragmented LLM Connections
The stratumiops ecosystem has 4 projects with AI functionality, each with its own implementation:
| Project | Implementation | Providers | Duplication |
| ------------ | -------------------------- | ---------------------- | ------------------- |
| Vapora | `typedialog-ai` (path dep) | Claude, OpenAI, Ollama | Shared base |
| TypeDialog | `typedialog-ai` (local) | Claude, OpenAI, Ollama | Defines abstraction |
| Provisioning | Custom `LlmClient` | Claude, OpenAI | 100% duplicated |
| Kogral | `rig-core` | Embeddings only | Different stack |
### Identified Problems
#### 1. Code Duplication
Provisioning reimplements what TypeDialog already has:
- reqwest HTTP client
- Headers: x-api-key, anthropic-version
- JSON body formatting
- Response parsing
- Error handling
**Impact**: ~500 duplicated lines, bugs fixed in one place don't propagate.
#### 2. API Keys Only, No CLI Detection
No project detects credentials from official CLIs:
```text
Claude CLI: ~/.config/claude/credentials.json
OpenAI CLI: ~/.config/openai/credentials.json
```
**Impact**: Users with Claude Pro/Max ($20-100/month) pay for API tokens when they could use their subscription.
#### 3. No Automatic Fallback
When a provider fails (rate limit, timeout), the request fails completely:
```text
Actual: Request → Claude API → Rate Limit → ERROR
Desired: Request → Claude API → Rate Limit → OpenAI → Success
```
#### 4. No Circuit Breaker
If Claude API is down, each request attempts to connect, fails, and propagates the error:
```text
Request 1 → Claude → Timeout (30s) → Error
Request 2 → Claude → Timeout (30s) → Error
Request 3 → Claude → Timeout (30s) → Error
```
**Impact**: Accumulated latency, degraded UX.
#### 5. No Caching
Identical requests always go to the API:
```text
"Explain this Rust error" → Claude → $0.003
"Explain this Rust error" → Claude → $0.003 (same result)
```
**Impact**: Unnecessary costs, especially in development/testing.
#### 6. Kogral Not Integrated
Kogral has guidelines and patterns that could enrich LLM context, but there's no integration.
## Decision
Create `stratum-llm` as a unified crate that:
1. **Consolidates** existing implementations from typedialog-ai and provisioning
2. **Detects** CLI credentials and subscriptions before using API keys
3. **Implements** automatic fallback with circuit breaker
4. **Adds** request caching to reduce costs
5. **Integrates** Kogral for context enrichment
6. **Is used** by all ecosystem projects
### Architecture
```text
┌─────────────────────────────────────────────────────────┐
│ stratum-llm │
├─────────────────────────────────────────────────────────┤
│ CredentialDetector │
│ ├─ Claude CLI → ~/.config/claude/ (subscription) │
│ ├─ OpenAI CLI → ~/.config/openai/ │
│ ├─ Env vars → *_API_KEY │
│ └─ Ollama → localhost:11434 (free) │
│ │ │
│ ▼ │
│ ProviderChain (ordered by priority) │
│ [CLI/Sub] → [API] → [DeepSeek] → [Ollama] │
│ │ │ │ │ │
│ └──────────┴─────────┴───────────┘ │
│ │ │
│ CircuitBreaker per provider │
│ │ │
│ RequestCache │
│ │ │
│ KogralIntegration │
│ │ │
│ UnifiedClient │
│ │
└─────────────────────────────────────────────────────────┘
```
## Rationale
### Why Not Use Another External Crate
| Alternative | Why Not |
| -------------- | ------------------------------------------ |
| kaccy-ai | Oriented toward blockchain/fraud detection |
| llm (crate) | Very basic, no circuit breaker or caching |
| langchain-rust | Python port, not idiomatic Rust |
| rig-core | Embeddings/RAG only, no chat completion |
**Best option**: Build on typedialog-ai and add missing features.
### Why CLI Detection is Important
Cost analysis for typical user:
| Scenario | Monthly Cost |
| ------------------------- | -------------------- |
| API only (current) | ~$840 |
| Claude Pro + API overflow | ~$20 + ~$200 = $220 |
| Claude Max + API overflow | ~$100 + ~$50 = $150 |
**Potential savings**: 70-80% by detecting and using subscriptions first.
### Why Circuit Breaker
Without circuit breaker, a downed provider causes:
- N requests × 30s timeout = N×30s total latency
- All resources occupied waiting for timeouts
With circuit breaker:
- First failure opens circuit
- Following requests fail immediately (fast fail)
- Fallback to another provider without waiting
- Circuit resets after cooldown
### Why Caching
For typical development:
- Same questions repeated while iterating
- Testing executes same prompts multiple times
Estimated cache hit rate: 15-30% in active development.
### Why Kogral Integration
Kogral has language guidelines, domain patterns, and ADRs.
Without integration the LLM generates generic code;
with integration it generates code following project conventions.
## Consequences
### Positive
1. Single source of truth for LLM logic
2. CLI detection reduces costs 70-80%
3. Circuit breaker + fallback = high availability
4. 15-30% fewer requests in development (caching)
5. Kogral improves generation quality
6. Feature-gated: each feature is optional
### Negative
1. **Migration effort**: Refactor Vapora, TypeDialog, Provisioning
2. **New dependency**: Projects depend on stratumiops
3. **CLI auth complexity**: Different credential formats per version
4. **Cache invalidation**: Stale responses if not managed well
### Mitigations
| Negative | Mitigation |
| ------------------- | ------------------------------------------- |
| Migration effort | Re-export compatible API from typedialog-ai |
| New dependency | Local path dependency, not crates.io |
| CLI auth complexity | Version detection, fallback to API if fails |
| Cache invalidation | Configurable TTL, bypass option |
## Success Metrics
| Metric | Current | Target |
| ------------------------ | ------- | --------------- |
| Duplicated lines of code | ~500 | 0 |
| CLI credential detection | 0% | 100% |
| Fallback success rate | 0% | >90% |
| Cache hit rate | 0% | 15-30% |
| Latency (provider down) | 30s+ | <1s (fast fail) |
## Cost Impact Analysis
Based on real usage data ($840/month):
| Scenario | Savings |
| -------------------------- | ------------------ |
| CLI detection (Claude Max) | ~$700/month |
| Caching (15% hit rate) | ~$50/month |
| DeepSeek fallback for code | ~$100/month |
| **Total potential** | **$500-700/month** |
## Migration Strategy
### Migration Phases
1. Create stratum-llm with API compatible with typedialog-ai
2. typedialog-ai re-exports stratum-llm (backward compatible)
3. Vapora migrates to stratum-llm directly
4. Provisioning migrates its LlmClient to stratum-llm
5. Deprecate typedialog-ai, consolidate in stratum-llm
### Feature Adoption
| Feature | Adoption |
| --------------- | ----------------------------------------- |
| Basic providers | Immediate (direct replacement) |
| CLI detection | Optional, feature flag |
| Circuit breaker | Default on |
| Caching | Default on, configurable TTL |
| Kogral | Feature flag, requires Kogral installed |
## Alternatives Considered
### Alternative 1: Improve typedialog-ai In-Place
**Pros**: No new crate required
**Cons**: TypeDialog is a specific project, not shared infrastructure
**Decision**: stratum-llm in stratumiops is better location for cross-project infrastructure.
### Alternative 2: Use LiteLLM (Python) as Proxy
**Pros**: Very complete, 100+ providers
**Cons**: Python dependency, proxy latency, not Rust-native
**Decision**: Keep pure Rust stack.
### Alternative 3: Each Project Maintains Its Own Implementation
**Pros**: Independence
**Cons**: Duplication, inconsistency, bugs not shared
**Decision**: Consolidation is better long-term.
## References
**Existing Implementations**:
- TypeDialog: `typedialog/crates/typedialog-ai/`
- Vapora: `vapora/crates/vapora-llm-router/`
- Provisioning: `provisioning/platform/crates/rag/`
**Kogral**: `kogral/`
**Target Location**: `stratumiops/crates/stratum-llm/`

View File

@ -0,0 +1,22 @@
# ADRs - Architecture Decision Records
Architecture decision records for the STRATUMIOPS ecosystem.
## Active ADRs
| ID | Title | Status |
| -------------------------------- | --------------------------------------------- | -------- |
| [001](001-stratum-embeddings.md) | Stratum-Embeddings: Unified Embedding Library | Proposed |
| [002](002-stratum-llm.md) | Stratum-LLM: Unified LLM Provider Library | Proposed |
## Statuses
- **Proposed**: Under review, pending implementation
- **Accepted**: Approved and implemented
- **Deprecated**: Replaced by another ADR
- **Superseded**: Obsolete, see replacement ADR
## Navigation
- [Back to architecture](../)
- [Spanish version](../../../es/architecture/adrs/)

View File

@ -34,6 +34,16 @@ Herramientas de automatización de infraestructura y despliegue.
Ver directorio [ops/](ops/) para documentación completa del portfolio de operaciones. Ver directorio [ops/](ops/) para documentación completa del portfolio de operaciones.
### Arquitectura
Decisiones arquitecturales y diseño del ecosistema.
- [**ADRs**](architecture/adrs/) - Architecture Decision Records
- [ADR-001: Stratum-Embeddings](architecture/adrs/001-stratum-embeddings.md) - Biblioteca unificada de embeddings
- [ADR-002: Stratum-LLM](architecture/adrs/002-stratum-llm.md) - Biblioteca unificada de providers LLM
Ver directorio [architecture/](architecture/) para documentación completa de arquitectura.
## Navegación ## Navegación
- [Volver a documentación raíz](../) - [Volver a documentación raíz](../)

View File

@ -0,0 +1,30 @@
# Arquitectura
Documentación de arquitectura del ecosistema STRATUMIOPS.
## Contenido
### ADRs (Architecture Decision Records)
Decisiones arquitecturales documentadas siguiendo el formato ADR:
- [**ADR-001: Stratum-Embeddings**](adrs/001-stratum-embeddings.md) - Biblioteca unificada de embeddings
- [**ADR-002: Stratum-LLM**](adrs/002-stratum-llm.md) - Biblioteca unificada de providers LLM
## Formato ADR
Cada ADR sigue la estructura:
| Sección | Descripción |
| ----------------- | ------------------------------------------ |
| Estado | Propuesto, Aceptado, Deprecado, Superseded |
| Contexto | Problema y estado actual |
| Decisión | Solución elegida |
| Justificación | Por qué esta solución |
| Consecuencias | Positivas, negativas, mitigaciones |
| Métricas de Éxito | Cómo medir el resultado |
## Navegación
- [Volver a documentación principal](../)
- [English version](../../en/architecture/)

View File

@ -0,0 +1,280 @@
# ADR-001: Stratum-Embeddings - Biblioteca Unificada de Embeddings
## Estado
**Propuesto**
## Contexto
### Estado Actual: Implementaciones Fragmentadas
El ecosistema tiene 3 implementaciones independientes de embeddings:
| Proyecto | Ubicación | Providers | Caching |
| ------------ | ------------------------------------- | ----------------------------- | ------- |
| Kogral | `kogral-core/src/embeddings/` | fastembed, rig-core (parcial) | No |
| Provisioning | `provisioning-rag/src/embeddings.rs` | OpenAI directo | No |
| Vapora | `vapora-llm-router/src/embeddings.rs` | OpenAI, HuggingFace, Ollama | No |
### Problemas Identificados
#### 1. Código Duplicado
Cada proyecto reimplementa:
- HTTP client para OpenAI embeddings
- Parsing de respuestas JSON
- Manejo de errores
- Token estimation
**Impacto**: ~400 líneas duplicadas, inconsistencias en manejo de errores.
#### 2. Sin Caching
Embeddings se regeneran cada vez:
```text
"What is Rust?" → OpenAI → 1536 dims → $0.00002
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (mismo resultado)
"What is Rust?" → OpenAI → 1536 dims → $0.00002 (mismo resultado)
```
**Impacto**: Costos innecesarios, latencia adicional, rate limits más frecuentes.
#### 3. No Hay Fallback
Si OpenAI falla, todo falla. No hay fallback a alternativas locales (fastembed, Ollama).
**Impacto**: Disponibilidad reducida, dependencia total de un provider.
#### 4. Dimension Mismatch Silencioso
Diferentes providers producen diferentes dimensiones:
| Provider | Modelo | Dimensiones |
| --------- | ---------------------- | ----------- |
| fastembed | bge-small-en | 384 |
| fastembed | bge-large-en | 1024 |
| OpenAI | text-embedding-3-small | 1536 |
| OpenAI | text-embedding-3-large | 3072 |
| Ollama | nomic-embed-text | 768 |
**Impacto**: Índices vectoriales corruptos si se cambia de provider.
#### 5. Sin Métricas
No hay visibilidad de uso, hit rate de cache, latencia por provider, ni costos acumulados.
## Decisión
Crear `stratum-embeddings` como crate unificado que:
1. **Unifique** las implementaciones de Kogral, Provisioning, y Vapora
2. **Añada caching** para evitar re-computar embeddings idénticos
3. **Implemente fallback** entre providers (cloud → local)
4. **Documente claramente** las dimensiones y limitaciones por provider
5. **Exponga métricas** para observabilidad
6. **Provea VectorStore trait** con backends LanceDB y SurrealDB según necesidad del proyecto
### Decisión de Backend de Storage
Cada proyecto elige su backend de vector storage según su prioridad:
| Proyecto | Backend | Prioridad | Justificación |
| ------------ | --------- | ----------------- | -------------------------------------------------------- |
| Kogral | SurrealDB | Riqueza del grafo | Knowledge Graph necesita queries unificados graph+vector |
| Provisioning | LanceDB | Escala vectorial | RAG con millones de chunks documentales |
| Vapora | LanceDB | Escala vectorial | Traces de ejecución, pattern matching a escala |
#### Por qué SurrealDB para Kogral
Kogral es un Knowledge Graph donde las relaciones son el valor principal.
Con arquitectura híbrida (LanceDB vectores + SurrealDB graph), un query típico requeriría:
1. LanceDB: búsqueda vectorial → candidate_ids
2. SurrealDB: filtro de grafo sobre candidates → results
3. App layer: merge, re-rank, deduplicación
**Trade-off aceptado**: SurrealDB tiene peor rendimiento vectorial puro que LanceDB,
pero la escala de Kogral está limitada por curación humana del conocimiento
(10K-100K conceptos típicamente).
#### Por qué LanceDB para Provisioning y Vapora
| Aspecto | SurrealDB | LanceDB |
| --------------- | ---------- | -------------------- |
| Storage format | Row-based | Columnar (Lance) |
| Vector index | HNSW (RAM) | IVF-PQ (disk-native) |
| Escala práctica | Millones | Billones |
| Compresión | ~1x | ~32x (PQ) |
| Zero-copy read | No | Sí |
### Arquitectura
```text
┌─────────────────────────────────────────────────────────────────┐
│ stratum-embeddings │
├─────────────────────────────────────────────────────────────────┤
│ EmbeddingProvider trait │
│ ├─ embed(text) → Vec<f32>
│ ├─ embed_batch(texts) → Vec<Vec<f32>> │
│ ├─ dimensions() → usize │
│ └─ is_local() → bool │
│ │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │
│ │ FastEmbed │ │ OpenAI │ │ Ollama │ │
│ │ (local) │ │ (cloud) │ │ (local) │ │
│ └───────────┘ └───────────┘ └───────────┘ │
│ └────────────┬────────────┘ │
│ ▼ │
│ EmbeddingCache (memory/disk) │
│ │ │
│ ▼ │
│ EmbeddingService │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ VectorStore trait │ │
│ │ ├─ upsert(id, embedding, metadata) │ │
│ │ ├─ search(embedding, limit, filter) → Vec<Match> │ │
│ │ └─ delete(id) │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ SurrealDbStore │ │ LanceDbStore │ │
│ │ (Kogral) │ │ (Prov/Vapora) │ │
│ └─────────────────┘ └─────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
## Justificación
### Por Qué Caching es Crítico
Para un sistema RAG típico (10,000 chunks):
- **Sin cache**: Re-indexaciones y queries repetidas multiplican costos
- **Con cache**: Primera indexación paga, resto son cache hits
**Ahorro estimado**: 60-80% en costos de embeddings.
### Por Qué Fallback es Importante
| Escenario | Sin Fallback | Con Fallback |
| ----------------- | ------------ | ------------------- |
| OpenAI rate limit | ERROR | → fastembed (local) |
| OpenAI downtime | ERROR | → Ollama (local) |
| Sin internet | ERROR | → fastembed (local) |
### Por Qué Providers Locales Primero
Para desarrollo: fastembed carga modelo local (~100MB), no requiere API keys, sin costos, funciona offline.
Para producción: OpenAI para calidad, fastembed como fallback.
## Consecuencias
### Positivas
1. Single source of truth para todo el ecosistema
2. 60-80% menos llamadas a APIs de embeddings (caching)
3. Alta disponibilidad con providers locales (fallback)
4. Métricas de uso y costos
5. Feature-gated: solo compila lo necesario
6. Storage flexibility: VectorStore trait permite elegir backend por proyecto
### Negativas
1. **Dimension lock-in**: Cambiar provider requiere re-indexar
2. **Cache invalidation**: Contenido actualizado puede servir embeddings stale
3. **Model download**: fastembed descarga ~100MB en primer uso
4. **Storage lock-in por proyecto**: Kogral atado a SurrealDB, otros a LanceDB
### Mitigaciones
| Negativo | Mitigación |
| ----------------- | ------------------------------------------------------ |
| Dimension lock-in | Documentar claramente, warn en cambio de provider |
| Cache stale | TTL configurable, opción de bypass |
| Model download | Mostrar progreso, cache en ~/.cache/fastembed |
| Storage lock-in | Decisión consciente basada en prioridades del proyecto |
## Métricas de Éxito
| Métrica | Actual | Objetivo |
| --------------------------- | ------ | -------- |
| Implementaciones duplicadas | 3 | 1 |
| Cache hit rate | 0% | >60% |
| Fallback availability | 0% | 100% |
| Cost per 10K embeddings | ~$0.20 | ~$0.05 |
## Guía de Selección de Provider
### Desarrollo
```rust
// Local, gratis, offline
let service = EmbeddingService::builder()
.with_provider(FastEmbedProvider::small()?) // 384 dims
.with_memory_cache()
.build()?;
```
### Producción (Calidad)
```rust
// OpenAI con fallback local
let service = EmbeddingService::builder()
.with_provider(OpenAiEmbeddingProvider::large()?) // 3072 dims
.with_provider(FastEmbedProvider::large()?) // Fallback
.with_memory_cache()
.build()?;
```
### Producción (Costo-Optimizado)
```rust
// OpenAI small con fallback
let service = EmbeddingService::builder()
.with_provider(OpenAiEmbeddingProvider::small()?) // 1536 dims
.with_provider(OllamaEmbeddingProvider::nomic()) // Fallback
.with_memory_cache()
.build()?;
```
## Matriz de Compatibilidad de Dimensiones
| Si usas... | Puedes cambiar a... | NO puedes cambiar a... |
| ---------------------- | --------------------------- | ---------------------- |
| fastembed small (384) | fastembed small, all-minilm | Cualquier otro |
| fastembed large (1024) | fastembed large | Cualquier otro |
| OpenAI small (1536) | OpenAI small, ada-002 | Cualquier otro |
| OpenAI large (3072) | OpenAI large | Cualquier otro |
**Regla**: Solo puedes cambiar entre modelos con las MISMAS dimensiones.
## Prioridad de Implementación
| Orden | Feature | Razón |
| ----- | ----------------------- | ---------------------------- |
| 1 | EmbeddingProvider trait | Base para todo |
| 2 | FastEmbed provider | Funciona sin API keys |
| 3 | Memory cache | Mayor impacto en costos |
| 4 | VectorStore trait | Abstracción de storage |
| 5 | SurrealDbStore | Kogral necesita graph+vector |
| 6 | LanceDbStore | Provisioning/Vapora escala |
| 7 | OpenAI provider | Producción |
| 8 | Ollama provider | Fallback local |
| 9 | Batch processing | Eficiencia |
| 10 | Metrics | Observabilidad |
## Referencias
**Implementaciones Existentes**:
- Kogral: `kogral-core/src/embeddings/`
- Vapora: `vapora-llm-router/src/embeddings.rs`
- Provisioning: `provisioning/platform/crates/rag/src/embeddings.rs`
**Ubicación Objetivo**: `stratumiops/crates/stratum-embeddings/`

View File

@ -0,0 +1,279 @@
# ADR-002: Stratum-LLM - Biblioteca Unificada de Providers LLM
## Estado
**Propuesto**
## Contexto
### Estado Actual: Conexiones LLM Fragmentadas
El ecosistema stratumiops tiene 4 proyectos con funcionalidad IA, cada uno con su propia implementación:
| Proyecto | Implementación | Providers | Duplicación |
| ------------ | -------------------------- | ---------------------- | --------------------- |
| Vapora | `typedialog-ai` (path dep) | Claude, OpenAI, Ollama | Base compartida |
| TypeDialog | `typedialog-ai` (local) | Claude, OpenAI, Ollama | Define la abstracción |
| Provisioning | Custom `LlmClient` | Claude, OpenAI | 100% duplicado |
| Kogral | `rig-core` | Solo embeddings | Diferente stack |
### Problemas Identificados
#### 1. Duplicación de Código
Provisioning reimplementa lo que TypeDialog ya tiene:
- reqwest HTTP client
- Headers: x-api-key, anthropic-version
- JSON body formatting
- Response parsing
- Error handling
**Impacto**: ~500 líneas duplicadas, bugs arreglados en un lugar no se propagan.
#### 2. Solo API Keys, No CLI Detection
Ningún proyecto detecta credenciales de CLIs oficiales:
```text
Claude CLI: ~/.config/claude/credentials.json
OpenAI CLI: ~/.config/openai/credentials.json
```
**Impacto**: Usuarios con Claude Pro/Max ($20-100/mes) pagan API tokens cuando podrían usar su suscripción.
#### 3. Sin Fallback Automático
Cuando un provider falla (rate limit, timeout), la request falla completamente:
```text
Actual: Request → Claude API → Rate Limit → ERROR
Deseado: Request → Claude API → Rate Limit → OpenAI → Success
```
#### 4. Sin Circuit Breaker
Si Claude API está caído, cada request intenta conectar, falla, y propaga el error:
```text
Request 1 → Claude → Timeout (30s) → Error
Request 2 → Claude → Timeout (30s) → Error
Request 3 → Claude → Timeout (30s) → Error
```
**Impacto**: Latencia acumulada, UX degradado.
#### 5. Sin Caching
Requests idénticas van siempre a la API:
```text
"Explain this Rust error" → Claude → $0.003
"Explain this Rust error" → Claude → $0.003 (mismo resultado)
```
**Impacto**: Costos innecesarios, especialmente en desarrollo/testing.
#### 6. Kogral No Integrado
Kogral tiene guidelines y patterns que podrían enriquecer el contexto de LLM, pero no hay integración.
## Decisión
Crear `stratum-llm` como crate unificado que:
1. **Consolide** las implementaciones existentes de typedialog-ai y provisioning
2. **Detecte** credenciales CLI y subscripciones antes de usar API keys
3. **Implemente** fallback automático con circuit breaker
4. **Añada** caching de requests para reducir costos
5. **Integre** Kogral para enriquecer contexto
6. **Sea usado** por todos los proyectos del ecosistema
### Arquitectura
```text
┌─────────────────────────────────────────────────────────┐
│ stratum-llm │
├─────────────────────────────────────────────────────────┤
│ CredentialDetector │
│ ├─ Claude CLI → ~/.config/claude/ (subscription) │
│ ├─ OpenAI CLI → ~/.config/openai/ │
│ ├─ Env vars → *_API_KEY │
│ └─ Ollama → localhost:11434 (free) │
│ │ │
│ ▼ │
│ ProviderChain (ordered by priority) │
│ [CLI/Sub] → [API] → [DeepSeek] → [Ollama] │
│ │ │ │ │ │
│ └──────────┴─────────┴───────────┘ │
│ │ │
│ CircuitBreaker per provider │
│ │ │
│ RequestCache │
│ │ │
│ KogralIntegration │
│ │ │
│ UnifiedClient │
│ │
└─────────────────────────────────────────────────────────┘
```
## Justificación
### Por Qué No Usar Otra Crate Externa
| Alternativa | Por Qué No |
| -------------- | ------------------------------------------ |
| kaccy-ai | Orientada a blockchain/fraud detection |
| llm (crate) | Muy básica, sin circuit breaker ni caching |
| langchain-rust | Port de Python, no idiomático Rust |
| rig-core | Solo embeddings/RAG, no chat completion |
**Mejor opción**: Construir sobre typedialog-ai y añadir features faltantes.
### Por Qué CLI Detection es Importante
Análisis de costos para usuario típico:
| Escenario | Costo Mensual |
| ------------------------- | -------------------- |
| Solo API (actual) | ~$840 |
| Claude Pro + API overflow | ~$20 + ~$200 = $220 |
| Claude Max + API overflow | ~$100 + ~$50 = $150 |
**Ahorro potencial**: 70-80% detectando y usando subscripciones primero.
### Por Qué Circuit Breaker
Sin circuit breaker, un provider caído causa:
- N requests × 30s timeout = N×30s de latencia total
- Todos los recursos ocupados esperando timeouts
Con circuit breaker:
- Primera falla abre circuito
- Siguientes requests fallan inmediatamente (fast fail)
- Fallback a otro provider sin esperar
- Circuito se resetea después de cooldown
### Por Qué Caching
Para desarrollo típico:
- Mismas preguntas repetidas mientras se itera
- Testing ejecuta mismos prompts múltiples veces
Cache hit rate estimado: 15-30% en desarrollo activo.
### Por Qué Kogral Integration
Kogral tiene guidelines por lenguaje, patterns por dominio, y ADRs.
Sin integración el LLM genera código genérico;
con integración genera código que sigue convenciones del proyecto.
## Consecuencias
### Positivas
1. Single source of truth para lógica de LLM
2. CLI detection reduce costos 70-80%
3. Circuit breaker + fallback = alta disponibilidad
4. 15-30% menos requests en desarrollo (caching)
5. Kogral mejora calidad de generación
6. Feature-gated: cada feature es opcional
### Negativas
1. **Migration effort**: Refactorizar Vapora, TypeDialog, Provisioning
2. **New dependency**: Proyectos dependen de stratumiops
3. **CLI auth complexity**: Diferentes formatos de credenciales por versión
4. **Cache invalidation**: Respuestas obsoletas si no se gestiona bien
### Mitigaciones
| Negativo | Mitigación |
| ------------------- | -------------------------------------------- |
| Migration effort | Re-export API compatible desde typedialog-ai |
| New dependency | Path dependency local, no crates.io |
| CLI auth complexity | Version detection, fallback a API si falla |
| Cache invalidation | TTL configurable, opción de bypass |
## Métricas de Éxito
| Métrica | Actual | Objetivo |
| --------------------------- | ------ | --------------- |
| Líneas de código duplicadas | ~500 | 0 |
| CLI credential detection | 0% | 100% |
| Fallback success rate | 0% | >90% |
| Cache hit rate | 0% | 15-30% |
| Latency (provider down) | 30s+ | <1s (fast fail) |
## Análisis de Impacto en Costos
Basado en datos reales de uso ($840/mes):
| Escenario | Ahorro |
| ----------------------------- | ---------------- |
| CLI detection (Claude Max) | ~$700/mes |
| Caching (15% hit rate) | ~$50/mes |
| DeepSeek fallback para código | ~$100/mes |
| **Total potencial** | **$500-700/mes** |
## Estrategia de Migración
### Fases de Migración
1. Crear stratum-llm con API compatible con typedialog-ai
2. typedialog-ai re-exporta stratum-llm (backward compatible)
3. Vapora migra a stratum-llm directamente
4. Provisioning migra su LlmClient a stratum-llm
5. Deprecar typedialog-ai, consolidar en stratum-llm
### Adopción de Features
| Feature | Adopción |
| --------------- | --------------------------------------- |
| Basic providers | Inmediata (reemplazo directo) |
| CLI detection | Opcional, feature flag |
| Circuit breaker | Default on |
| Caching | Default on, configurable TTL |
| Kogral | Feature flag, requiere Kogral instalado |
## Alternativas Consideradas
### Alternativa 1: Mejorar typedialog-ai In-Place
**Pros**: No requiere nuevo crate
**Cons**: TypeDialog es proyecto específico, no infraestructura compartida
**Decisión**: stratum-llm en stratumiops es mejor ubicación para infraestructura cross-project.
### Alternativa 2: Usar LiteLLM (Python) como Proxy
**Pros**: Muy completo, 100+ providers
**Cons**: Dependencia Python, latencia de proxy, no Rust-native
**Decisión**: Mantener stack Rust puro.
### Alternativa 3: Cada Proyecto Mantiene su Implementación
**Pros**: Independencia
**Cons**: Duplicación, inconsistencia, bugs no compartidos
**Decisión**: Consolidar es mejor a largo plazo.
## Referencias
**Implementaciones Existentes**:
- TypeDialog: `typedialog/crates/typedialog-ai/`
- Vapora: `vapora/crates/vapora-llm-router/`
- Provisioning: `provisioning/platform/crates/rag/`
**Kogral**: `kogral/`
**Ubicación Objetivo**: `stratumiops/crates/stratum-llm/`

View File

@ -0,0 +1,22 @@
# ADRs - Architecture Decision Records
Registro de decisiones arquitecturales del ecosistema STRATUMIOPS.
## ADRs Activos
| ID | Título | Estado |
| -------------------------------- | ------------------------------------------------------ | --------- |
| [001](001-stratum-embeddings.md) | Stratum-Embeddings: Biblioteca Unificada de Embeddings | Propuesto |
| [002](002-stratum-llm.md) | Stratum-LLM: Biblioteca Unificada de Providers LLM | Propuesto |
## Estados
- **Propuesto**: En revisión, pendiente de implementación
- **Aceptado**: Aprobado e implementado
- **Deprecado**: Reemplazado por otro ADR
- **Superseded**: Obsoleto, ver ADR de reemplazo
## Navegación
- [Volver a arquitectura](../)
- [English version](../../../en/architecture/adrs/)