763 lines
24 KiB
Rust
763 lines
24 KiB
Rust
// RLM Engine - Core Orchestration
|
|
// Coordinates chunking, storage, hybrid search, and LLM dispatch
|
|
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use tracing::{debug, info, warn};
|
|
use vapora_llm_router::providers::LLMClient;
|
|
|
|
use crate::chunking::{create_chunker, ChunkingConfig};
|
|
use crate::dispatch::{AggregatedResult, LLMDispatcher};
|
|
use crate::embeddings::{EmbeddingConfig, EmbeddingGenerator};
|
|
use crate::metrics::{CHUNKS_TOTAL, QUERY_DURATION};
|
|
use crate::search::bm25::BM25Index;
|
|
use crate::search::hybrid::{HybridSearch, ScoredChunk};
|
|
use crate::storage::{Chunk, Storage};
|
|
use crate::RLMError;
|
|
|
|
/// RLM Engine configuration
///
/// Controls chunking defaults, optional embedding generation, index
/// maintenance behavior, and per-document safety limits.
#[derive(Debug, Clone)]
pub struct RLMEngineConfig {
    /// Default chunking configuration, used when `load_document` is called
    /// without an explicit `ChunkingConfig`.
    pub chunking: ChunkingConfig,
    /// Embedding configuration (optional - if None, no embeddings are
    /// generated and chunks are persisted without vectors).
    pub embedding: Option<EmbeddingConfig>,
    /// Enable automatic BM25 index rebuilds (currently only consulted on
    /// document deletion, where it triggers a stale-entries warning).
    pub auto_rebuild_bm25: bool,
    /// Maximum chunks per document (safety limit); `load_document` returns
    /// an error if chunking would exceed this.
    pub max_chunks_per_doc: usize,
}
|
|
|
|
impl Default for RLMEngineConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
chunking: ChunkingConfig::default(),
|
|
embedding: Some(EmbeddingConfig::default()), // Enable embeddings by default
|
|
auto_rebuild_bm25: true,
|
|
max_chunks_per_doc: 10_000,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// RLM Engine - orchestrates chunking, storage, and hybrid search
///
/// Generic over the [`Storage`] backend so tests can substitute an
/// in-memory implementation.
pub struct RLMEngine<S: Storage> {
    /// Persistent chunk store (a clone of this handle is also held by
    /// `hybrid_search`).
    storage: Arc<S>,
    /// Keyword (BM25) index; shared with `hybrid_search`.
    bm25_index: Arc<BM25Index>,
    /// Combines semantic and BM25 results (RRF fusion).
    hybrid_search: HybridSearch<S>,
    /// Present only when embedding generation is enabled in the config.
    embedding_generator: Option<Arc<EmbeddingGenerator>>,
    /// Routes subtasks to an LLM; may be constructed without a client.
    dispatcher: Arc<LLMDispatcher>,
    /// Engine-level configuration (chunking defaults, limits).
    config: RLMEngineConfig,
}
|
|
|
|
impl<S: Storage> RLMEngine<S> {
|
|
/// Create a new RLM engine
|
|
pub fn new(storage: Arc<S>, bm25_index: Arc<BM25Index>) -> crate::Result<Self> {
|
|
let hybrid_search = HybridSearch::new(storage.clone(), bm25_index.clone())?;
|
|
let config = RLMEngineConfig::default();
|
|
|
|
let embedding_generator = config
|
|
.embedding
|
|
.as_ref()
|
|
.map(|cfg| Arc::new(EmbeddingGenerator::new(cfg.clone())));
|
|
|
|
// Phase 6: No LLM client configured by default
|
|
let dispatcher = Arc::new(LLMDispatcher::new(None));
|
|
|
|
Ok(Self {
|
|
storage,
|
|
bm25_index,
|
|
hybrid_search,
|
|
embedding_generator,
|
|
dispatcher,
|
|
config,
|
|
})
|
|
}
|
|
|
|
/// Create with custom configuration
|
|
pub fn with_config(
|
|
storage: Arc<S>,
|
|
bm25_index: Arc<BM25Index>,
|
|
config: RLMEngineConfig,
|
|
) -> crate::Result<Self> {
|
|
let hybrid_search = HybridSearch::new(storage.clone(), bm25_index.clone())?;
|
|
|
|
let embedding_generator = config
|
|
.embedding
|
|
.as_ref()
|
|
.map(|cfg| Arc::new(EmbeddingGenerator::new(cfg.clone())));
|
|
|
|
// Phase 6: No LLM client configured by default
|
|
let dispatcher = Arc::new(LLMDispatcher::new(None));
|
|
|
|
Ok(Self {
|
|
storage,
|
|
bm25_index,
|
|
hybrid_search,
|
|
embedding_generator,
|
|
dispatcher,
|
|
config,
|
|
})
|
|
}
|
|
|
|
/// Create with LLM client for production use
|
|
pub fn with_llm_client(
|
|
storage: Arc<S>,
|
|
bm25_index: Arc<BM25Index>,
|
|
llm_client: Arc<dyn LLMClient + Send + Sync>,
|
|
config: Option<RLMEngineConfig>,
|
|
) -> crate::Result<Self> {
|
|
let config = config.unwrap_or_default();
|
|
let hybrid_search = HybridSearch::new(storage.clone(), bm25_index.clone())?;
|
|
|
|
let embedding_generator = config
|
|
.embedding
|
|
.as_ref()
|
|
.map(|cfg| Arc::new(EmbeddingGenerator::new(cfg.clone())));
|
|
|
|
// Production: LLM client configured
|
|
let dispatcher = Arc::new(LLMDispatcher::new(Some(llm_client)));
|
|
|
|
Ok(Self {
|
|
storage,
|
|
bm25_index,
|
|
hybrid_search,
|
|
embedding_generator,
|
|
dispatcher,
|
|
config,
|
|
})
|
|
}
|
|
|
|
/// Load a document: chunk → embed (placeholder) → persist → index
|
|
///
|
|
/// # Arguments
|
|
/// - `doc_id`: Unique document identifier
|
|
/// - `content`: Document content to chunk
|
|
/// - `chunking_config`: Optional chunking configuration (uses default if
|
|
/// None)
|
|
///
|
|
/// # Returns
|
|
/// Number of chunks created
|
|
pub async fn load_document(
|
|
&self,
|
|
doc_id: &str,
|
|
content: &str,
|
|
chunking_config: Option<ChunkingConfig>,
|
|
) -> crate::Result<usize> {
|
|
let start = Instant::now();
|
|
info!("Loading document: {}", doc_id);
|
|
|
|
// Use provided config or default
|
|
let config = chunking_config.unwrap_or_else(|| self.config.chunking.clone());
|
|
|
|
// Create chunker and chunk content
|
|
let chunker = create_chunker(&config);
|
|
let chunk_results = chunker.chunk(content)?;
|
|
|
|
// Safety check
|
|
if chunk_results.len() > self.config.max_chunks_per_doc {
|
|
warn!(
|
|
"Document {} has {} chunks, exceeds max {}",
|
|
doc_id,
|
|
chunk_results.len(),
|
|
self.config.max_chunks_per_doc
|
|
);
|
|
return Err(RLMError::ChunkingError(format!(
|
|
"Document exceeds max chunks: {} > {}",
|
|
chunk_results.len(),
|
|
self.config.max_chunks_per_doc
|
|
)));
|
|
}
|
|
|
|
debug!(
|
|
"Chunked document {} into {} chunks using {:?} strategy",
|
|
doc_id,
|
|
chunk_results.len(),
|
|
config.strategy
|
|
);
|
|
|
|
// Generate embeddings if enabled
|
|
let embeddings = if let Some(ref generator) = self.embedding_generator {
|
|
debug!("Generating embeddings for {} chunks", chunk_results.len());
|
|
let texts: Vec<String> = chunk_results.iter().map(|c| c.content.clone()).collect();
|
|
Some(generator.embed_batch(&texts).await?)
|
|
} else {
|
|
debug!("Embedding generation disabled");
|
|
None
|
|
};
|
|
|
|
// Convert ChunkResult to Chunk and persist
|
|
let mut chunks = Vec::new();
|
|
for (idx, chunk_result) in chunk_results.iter().enumerate() {
|
|
let chunk_id = format!("{}-chunk-{}", doc_id, idx);
|
|
|
|
// Get embedding for this chunk (if generated)
|
|
let embedding = embeddings.as_ref().and_then(|embs| embs.get(idx)).cloned();
|
|
|
|
let chunk = Chunk {
|
|
chunk_id: chunk_id.clone(),
|
|
doc_id: doc_id.to_string(),
|
|
content: chunk_result.content.clone(),
|
|
embedding, // Phase 5: Real embeddings from multi-provider
|
|
start_idx: chunk_result.start_idx,
|
|
end_idx: chunk_result.end_idx,
|
|
metadata: None,
|
|
created_at: chrono::Utc::now().to_rfc3339(),
|
|
};
|
|
|
|
// Save to storage
|
|
self.storage.save_chunk(chunk.clone()).await?;
|
|
|
|
// Add to BM25 index
|
|
self.bm25_index.add_document(&chunk)?;
|
|
|
|
chunks.push(chunk);
|
|
}
|
|
|
|
// Commit BM25 index
|
|
self.bm25_index.commit()?;
|
|
|
|
// Update metrics
|
|
CHUNKS_TOTAL
|
|
.with_label_values(&[&format!("{:?}", config.strategy)])
|
|
.inc_by(chunks.len() as u64);
|
|
|
|
let duration = start.elapsed();
|
|
info!(
|
|
"Loaded document {} with {} chunks in {:?}",
|
|
doc_id,
|
|
chunks.len(),
|
|
duration
|
|
);
|
|
|
|
Ok(chunks.len())
|
|
}
|
|
|
|
/// Query with hybrid search (semantic + BM25 + RRF fusion)
|
|
///
|
|
/// # Arguments
|
|
/// - `doc_id`: Document to search within
|
|
/// - `query_text`: Keyword query for BM25
|
|
/// - `query_embedding`: Optional vector embedding for semantic search
|
|
/// - `limit`: Maximum results to return
|
|
///
|
|
/// # Returns
|
|
/// Scored chunks ranked by hybrid search
|
|
pub async fn query(
|
|
&self,
|
|
doc_id: &str,
|
|
query_text: &str,
|
|
query_embedding: Option<&[f32]>,
|
|
limit: usize,
|
|
) -> crate::Result<Vec<ScoredChunk>> {
|
|
let start = Instant::now();
|
|
|
|
let results = if let Some(embedding) = query_embedding {
|
|
// Full hybrid search: BM25 + semantic + RRF
|
|
debug!(
|
|
"Hybrid query: doc={}, query='{}', limit={}",
|
|
doc_id, query_text, limit
|
|
);
|
|
self.hybrid_search
|
|
.search(doc_id, query_text, embedding, limit)
|
|
.await?
|
|
} else {
|
|
// BM25-only search (no embedding provided)
|
|
debug!(
|
|
"BM25-only query: doc={}, query='{}', limit={}",
|
|
doc_id, query_text, limit
|
|
);
|
|
let bm25_results = self.hybrid_search.bm25_search(query_text, limit)?;
|
|
|
|
// Get chunks from storage
|
|
let all_chunks = self.storage.get_chunks(doc_id).await?;
|
|
|
|
// Map BM25 results to ScoredChunk
|
|
bm25_results
|
|
.into_iter()
|
|
.filter_map(|bm25_result| {
|
|
all_chunks
|
|
.iter()
|
|
.find(|c| c.chunk_id == bm25_result.chunk_id)
|
|
.map(|chunk| ScoredChunk {
|
|
chunk: chunk.clone(),
|
|
score: bm25_result.score,
|
|
bm25_score: Some(bm25_result.score),
|
|
semantic_score: None,
|
|
})
|
|
})
|
|
.collect()
|
|
};
|
|
|
|
let duration = start.elapsed();
|
|
QUERY_DURATION
|
|
.with_label_values(&[if query_embedding.is_some() {
|
|
"hybrid"
|
|
} else {
|
|
"bm25_only"
|
|
}])
|
|
.observe(duration.as_secs_f64());
|
|
|
|
debug!("Query returned {} results in {:?}", results.len(), duration);
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
/// Dispatch subtask to LLM for distributed reasoning
|
|
///
|
|
/// # Arguments
|
|
/// - `doc_id`: Document to query
|
|
/// - `query_text`: Query/task description
|
|
/// - `query_embedding`: Optional embedding for hybrid search
|
|
/// - `limit`: Max chunks to retrieve
|
|
///
|
|
/// # Returns
|
|
/// Aggregated result from LLM analysis of relevant chunks
|
|
pub async fn dispatch_subtask(
|
|
&self,
|
|
doc_id: &str,
|
|
query_text: &str,
|
|
query_embedding: Option<&[f32]>,
|
|
limit: usize,
|
|
) -> crate::Result<AggregatedResult> {
|
|
info!("Dispatching subtask: doc={}, query={}", doc_id, query_text);
|
|
|
|
// Step 1: Retrieve relevant chunks via hybrid search
|
|
let chunks = self
|
|
.query(doc_id, query_text, query_embedding, limit)
|
|
.await?;
|
|
|
|
debug!("Retrieved {} chunks for dispatch", chunks.len());
|
|
|
|
// Step 2: Dispatch to LLM
|
|
let result = self.dispatcher.dispatch(query_text, &chunks).await?;
|
|
|
|
info!(
|
|
"Dispatch completed: {} LLM calls, {} total tokens",
|
|
result.num_calls,
|
|
result.total_input_tokens + result.total_output_tokens
|
|
);
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get BM25 index statistics
|
|
pub fn index_stats(&self) -> crate::search::bm25::IndexStats {
|
|
self.bm25_index.stats()
|
|
}
|
|
|
|
/// Rebuild BM25 index from all chunks for a document
|
|
pub async fn rebuild_index(&self, doc_id: &str) -> crate::Result<()> {
|
|
info!("Rebuilding BM25 index for document: {}", doc_id);
|
|
let chunks = self.storage.get_chunks(doc_id).await?;
|
|
self.bm25_index.rebuild_from_chunks(&chunks)?;
|
|
info!(
|
|
"Rebuilt BM25 index for {} with {} chunks",
|
|
doc_id,
|
|
chunks.len()
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Delete all chunks for a document
|
|
pub async fn delete_document(&self, doc_id: &str) -> crate::Result<u64> {
|
|
info!("Deleting document: {}", doc_id);
|
|
let deleted_count = self.storage.delete_chunks(doc_id).await?;
|
|
|
|
// Rebuild BM25 index to remove deleted chunks
|
|
if self.config.auto_rebuild_bm25 {
|
|
// For now, we can't selectively delete from BM25, so we'd need to rebuild
|
|
// For Phase 3, we'll just warn - full rebuild happens on next load
|
|
warn!(
|
|
"BM25 index may contain stale entries for deleted doc {}. Rebuild recommended.",
|
|
doc_id
|
|
);
|
|
}
|
|
|
|
Ok(deleted_count)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use std::collections::HashMap;
|
|
use std::sync::Mutex;
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use super::*;
|
|
use crate::chunking::ChunkingStrategy;
|
|
use crate::storage::{Buffer, ExecutionHistory};
|
|
|
|
// Mock storage for testing
|
|
struct MockStorage {
|
|
chunks: Arc<Mutex<HashMap<String, Vec<Chunk>>>>,
|
|
}
|
|
|
|
impl MockStorage {
|
|
fn new() -> Self {
|
|
Self {
|
|
chunks: Arc::new(Mutex::new(HashMap::new())),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Storage for MockStorage {
|
|
async fn save_chunk(&self, chunk: Chunk) -> crate::Result<()> {
|
|
let mut chunks = self.chunks.lock().unwrap();
|
|
chunks.entry(chunk.doc_id.clone()).or_default().push(chunk);
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_chunks(&self, doc_id: &str) -> crate::Result<Vec<Chunk>> {
|
|
let chunks = self.chunks.lock().unwrap();
|
|
Ok(chunks.get(doc_id).cloned().unwrap_or_default())
|
|
}
|
|
|
|
async fn get_chunk(&self, chunk_id: &str) -> crate::Result<Option<Chunk>> {
|
|
let chunks = self.chunks.lock().unwrap();
|
|
for chunk_list in chunks.values() {
|
|
if let Some(chunk) = chunk_list.iter().find(|c| c.chunk_id == chunk_id) {
|
|
return Ok(Some(chunk.clone()));
|
|
}
|
|
}
|
|
Ok(None)
|
|
}
|
|
|
|
async fn search_by_embedding(
|
|
&self,
|
|
_embedding: &[f32],
|
|
_limit: usize,
|
|
) -> crate::Result<Vec<Chunk>> {
|
|
Ok(Vec::new())
|
|
}
|
|
|
|
async fn save_buffer(&self, _buffer: Buffer) -> crate::Result<()> {
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_buffer(&self, _buffer_id: &str) -> crate::Result<Option<Buffer>> {
|
|
Ok(None)
|
|
}
|
|
|
|
async fn cleanup_expired_buffers(&self) -> crate::Result<u64> {
|
|
Ok(0)
|
|
}
|
|
|
|
async fn save_execution(&self, _execution: ExecutionHistory) -> crate::Result<()> {
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_executions(
|
|
&self,
|
|
_doc_id: &str,
|
|
_limit: usize,
|
|
) -> crate::Result<Vec<ExecutionHistory>> {
|
|
Ok(Vec::new())
|
|
}
|
|
|
|
async fn delete_chunks(&self, doc_id: &str) -> crate::Result<u64> {
|
|
let mut chunks = self.chunks.lock().unwrap();
|
|
let count = chunks.remove(doc_id).map(|v| v.len()).unwrap_or(0);
|
|
Ok(count as u64)
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_engine_creation() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
|
|
let engine = RLMEngine::new(storage, bm25_index);
|
|
assert!(engine.is_ok());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_load_document_fixed_chunking() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage.clone(), bm25_index).unwrap();
|
|
|
|
let content = "a".repeat(250); // 250 chars
|
|
let config = ChunkingConfig {
|
|
strategy: ChunkingStrategy::Fixed,
|
|
chunk_size: 100,
|
|
overlap: 20,
|
|
};
|
|
|
|
let chunk_count = engine
|
|
.load_document("doc-1", &content, Some(config))
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(chunk_count >= 2, "Should create at least 2 chunks");
|
|
|
|
// Verify chunks are persisted
|
|
let chunks = storage.get_chunks("doc-1").await.unwrap();
|
|
assert_eq!(chunks.len(), chunk_count);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_load_document_semantic_chunking() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage.clone(), bm25_index).unwrap();
|
|
|
|
let content = "First sentence. Second sentence! Third sentence?";
|
|
let config = ChunkingConfig {
|
|
strategy: ChunkingStrategy::Semantic,
|
|
chunk_size: 50,
|
|
overlap: 10,
|
|
};
|
|
|
|
let chunk_count = engine
|
|
.load_document("doc-2", content, Some(config))
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(chunk_count > 0, "Should create at least 1 chunk");
|
|
|
|
// Verify chunks are persisted
|
|
let chunks = storage.get_chunks("doc-2").await.unwrap();
|
|
assert_eq!(chunks.len(), chunk_count);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_query_bm25_only() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage.clone(), bm25_index).unwrap();
|
|
|
|
// Load document
|
|
let content =
|
|
"Rust programming language. Python programming tutorial. Rust async patterns.";
|
|
engine.load_document("doc-3", content, None).await.unwrap();
|
|
|
|
// Query (BM25-only, no embedding)
|
|
let results = engine.query("doc-3", "Rust", None, 5).await.unwrap();
|
|
|
|
assert!(!results.is_empty(), "Should find results for 'Rust'");
|
|
assert!(results[0].bm25_score.is_some(), "Should have BM25 score");
|
|
assert!(
|
|
results[0].semantic_score.is_none(),
|
|
"Should not have semantic score"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_query_hybrid_search() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage.clone(), bm25_index).unwrap();
|
|
|
|
// Load document with manual chunk creation (to add embeddings)
|
|
let chunk = Chunk {
|
|
chunk_id: "doc-4-chunk-0".to_string(),
|
|
doc_id: "doc-4".to_string(),
|
|
content: "Rust programming language".to_string(),
|
|
embedding: Some(vec![1.0, 0.0, 0.0]),
|
|
start_idx: 0,
|
|
end_idx: 26,
|
|
metadata: None,
|
|
created_at: chrono::Utc::now().to_rfc3339(),
|
|
};
|
|
|
|
storage.save_chunk(chunk.clone()).await.unwrap();
|
|
engine.bm25_index.add_document(&chunk).unwrap();
|
|
engine.bm25_index.commit().unwrap();
|
|
|
|
// Query with embedding (hybrid search)
|
|
let query_embedding = vec![0.9, 0.1, 0.0];
|
|
let results = engine
|
|
.query("doc-4", "Rust", Some(&query_embedding), 5)
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(!results.is_empty(), "Should find results");
|
|
// In hybrid search, we should have both scores (if RRF found matches in both)
|
|
// But with only 1 chunk, we might only get BM25 or semantic
|
|
assert!(
|
|
results[0].bm25_score.is_some() || results[0].semantic_score.is_some(),
|
|
"Should have at least one score"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_delete_document() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage.clone(), bm25_index).unwrap();
|
|
|
|
// Load document
|
|
engine
|
|
.load_document("doc-5", "Test content", None)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Verify it exists
|
|
let chunks_before = storage.get_chunks("doc-5").await.unwrap();
|
|
assert!(!chunks_before.is_empty());
|
|
|
|
// Delete
|
|
let deleted = engine.delete_document("doc-5").await.unwrap();
|
|
assert_eq!(deleted, chunks_before.len() as u64);
|
|
|
|
// Verify deletion
|
|
let chunks_after = storage.get_chunks("doc-5").await.unwrap();
|
|
assert!(chunks_after.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_max_chunks_safety_limit() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
|
|
let config = RLMEngineConfig {
|
|
max_chunks_per_doc: 5, // Very low limit for testing
|
|
..Default::default()
|
|
};
|
|
|
|
let engine = RLMEngine::with_config(storage, bm25_index, config).unwrap();
|
|
|
|
// Create content that will exceed limit
|
|
let content = "a".repeat(1000); // Will create many small chunks
|
|
let chunking_config = ChunkingConfig {
|
|
strategy: ChunkingStrategy::Fixed,
|
|
chunk_size: 10,
|
|
overlap: 0,
|
|
};
|
|
|
|
let result = engine
|
|
.load_document("doc-6", &content, Some(chunking_config))
|
|
.await;
|
|
|
|
assert!(
|
|
result.is_err(),
|
|
"Should fail when exceeding max chunks limit"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_index_stats() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
let engine = RLMEngine::new(storage, bm25_index).unwrap();
|
|
|
|
// Initially empty
|
|
let stats = engine.index_stats();
|
|
assert_eq!(stats.num_docs, 0);
|
|
|
|
// Load document
|
|
engine
|
|
.load_document("doc-7", "Test content", None)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Check stats again
|
|
let stats = engine.index_stats();
|
|
assert!(stats.num_docs > 0);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_embeddings_generated() {
|
|
use crate::embeddings::EmbeddingConfig;
|
|
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
|
|
// Create config with embeddings enabled
|
|
let config = RLMEngineConfig {
|
|
embedding: Some(EmbeddingConfig::openai_small()),
|
|
..Default::default()
|
|
};
|
|
|
|
let engine = RLMEngine::with_config(storage.clone(), bm25_index, config).unwrap();
|
|
|
|
// Load document
|
|
let content = "First chunk. Second chunk. Third chunk.";
|
|
engine.load_document("doc-8", content, None).await.unwrap();
|
|
|
|
// Verify chunks have embeddings
|
|
let chunks = storage.get_chunks("doc-8").await.unwrap();
|
|
assert!(!chunks.is_empty(), "Should have created chunks");
|
|
|
|
for chunk in &chunks {
|
|
assert!(
|
|
chunk.embedding.is_some(),
|
|
"Chunk {} should have embedding",
|
|
chunk.chunk_id
|
|
);
|
|
assert_eq!(
|
|
chunk.embedding.as_ref().unwrap().len(),
|
|
1536,
|
|
"Embedding should have 1536 dimensions (OpenAI small)"
|
|
);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_embeddings_disabled() {
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
|
|
// Create config with embeddings disabled
|
|
let config = RLMEngineConfig {
|
|
embedding: None,
|
|
..Default::default()
|
|
};
|
|
|
|
let engine = RLMEngine::with_config(storage.clone(), bm25_index, config).unwrap();
|
|
|
|
// Load document
|
|
let content = "Test content without embeddings";
|
|
engine.load_document("doc-9", content, None).await.unwrap();
|
|
|
|
// Verify chunks do NOT have embeddings
|
|
let chunks = storage.get_chunks("doc-9").await.unwrap();
|
|
assert!(!chunks.is_empty(), "Should have created chunks");
|
|
|
|
for chunk in &chunks {
|
|
assert!(
|
|
chunk.embedding.is_none(),
|
|
"Chunk {} should not have embedding when disabled",
|
|
chunk.chunk_id
|
|
);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_query_with_embeddings() {
|
|
use crate::embeddings::EmbeddingConfig;
|
|
|
|
let storage = Arc::new(MockStorage::new());
|
|
let bm25_index = Arc::new(BM25Index::new().unwrap());
|
|
|
|
// Create config with embeddings enabled
|
|
let config = RLMEngineConfig {
|
|
embedding: Some(EmbeddingConfig::openai_small()),
|
|
..Default::default()
|
|
};
|
|
|
|
let engine = RLMEngine::with_config(storage.clone(), bm25_index, config).unwrap();
|
|
|
|
// Load document with embeddings
|
|
let content = "Rust programming language. Python tutorial. JavaScript guide.";
|
|
engine.load_document("doc-10", content, None).await.unwrap();
|
|
|
|
// Get a chunk to use its embedding as query
|
|
let chunks = storage.get_chunks("doc-10").await.unwrap();
|
|
assert!(!chunks.is_empty());
|
|
let query_embedding = chunks[0].embedding.as_ref().unwrap();
|
|
|
|
// Query with embedding (hybrid search)
|
|
let results = engine
|
|
.query("doc-10", "Rust", Some(query_embedding), 3)
|
|
.await
|
|
.unwrap();
|
|
|
|
assert!(!results.is_empty(), "Should find results");
|
|
// With real embeddings, should get both BM25 and semantic scores
|
|
}
|
|
}
|