325 lines
9.7 KiB
Rust
Raw Normal View History

2026-02-16 05:09:51 +00:00
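// RLM Integration Tests
//
// These tests require a running SurrealDB instance that accepts the root
// credentials used by `setup_test_db` (root/root), for example:
//   docker run -p 8000:8000 surrealdb/surrealdb:latest start \
//     --bind 0.0.0.0:8000 --user root --pass root
// Every test is marked `#[ignore]`; run them with `cargo test -- --ignored`.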
use surrealdb::engine::remote::ws::Ws;
use surrealdb::opt::auth::Root;
use surrealdb::Surreal;
use vapora_knowledge_graph::persistence::{KGPersistence, PersistedRlmExecution};
use vapora_knowledge_graph::TimePeriod;
/// Opens a WebSocket connection to SurrealDB at `127.0.0.1:8000`, signs in
/// as root, selects the `test_rlm` namespace and database, and wraps the
/// client in a [`KGPersistence`].
///
/// Panics if connecting, signing in, or selecting the namespace/database
/// fails.
async fn setup_test_db() -> KGPersistence {
    let client = Surreal::new::<Ws>("127.0.0.1:8000").await.unwrap();

    let root_credentials = Root {
        username: "root",
        password: "root",
    };
    client.signin(root_credentials).await.unwrap();

    // All tests in this file share one namespace/database pair.
    client
        .use_ns("test_rlm")
        .use_db("test_rlm")
        .await
        .unwrap();

    KGPersistence::new(client)
}
/// Persists one fully populated execution and checks that it can be read
/// back through `get_rlm_executions_by_doc`.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_persist_rlm_execution() {
    let persistence = setup_test_db().await;

    // Both ids are unique per run: with a fixed doc id, rows left behind by
    // earlier runs would share the document and `executions[0]` would depend
    // on the store's result ordering.
    let doc_id = format!("doc-{}", uuid::Uuid::new_v4());
    let exec_id = format!("test-exec-{}", uuid::Uuid::new_v4());

    let execution = PersistedRlmExecution::builder(
        exec_id.clone(),
        doc_id.clone(),
        "What is Rust?".to_string(),
    )
    .chunks_used(vec!["chunk-1".to_string(), "chunk-2".to_string()])
    .result("Rust is a systems programming language".to_string())
    .duration_ms(3000)
    .tokens(800, 400)
    .num_llm_calls(2)
    .provider("claude".to_string())
    .success(true)
    .cost_cents(120.0)
    .aggregation_strategy("Concatenate".to_string())
    .build();

    let result = persistence.persist_rlm_execution(execution).await;
    assert!(
        result.is_ok(),
        "Failed to persist RLM execution: {:?}",
        result.err()
    );

    // Verify the record can be retrieved by its document id.
    let executions = persistence
        .get_rlm_executions_by_doc(&doc_id, 10)
        .await
        .unwrap();
    assert!(!executions.is_empty());
    assert_eq!(executions[0].execution_id, exec_id);
}
/// Stores ten executions for a fresh document, a mix of successes and
/// failures, then checks that the learning curve is non-empty and ordered
/// by its first tuple field.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_get_rlm_learning_curve() {
    let persistence = setup_test_db().await;
    let doc_id = format!("doc-{}", uuid::Uuid::new_v4());

    // Create multiple executions over time.
    for i in 0..10 {
        let builder = PersistedRlmExecution::builder(
            format!("exec-{}-{}", doc_id, i),
            doc_id.clone(),
            format!("Query {}", i),
        )
        .chunks_used(vec![format!("chunk-{}", i)])
        .result(format!("Result {}", i))
        .duration_ms(1000 + (i as u64 * 100))
        .tokens(800, 400)
        .provider("claude".to_string());

        // Every third execution (i = 0, 3, 6, 9) is recorded as a failure
        // with an error message; the other six succeed (60% success rate).
        let builder = if i % 3 == 0 {
            builder.error("Test error".to_string())
        } else {
            builder.success(true)
        };

        persistence
            .persist_rlm_execution(builder.build())
            .await
            .unwrap();

        // Small delay to ensure different timestamps.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }

    // Get learning curve.
    let curve = persistence
        .get_rlm_learning_curve(&doc_id, 1)
        .await
        .unwrap();
    assert!(!curve.is_empty(), "Learning curve should not be empty");

    // Verify chronological ordering of adjacent points.
    for pair in curve.windows(2) {
        assert!(
            pair[0].0 <= pair[1].0,
            "Curve must be chronologically sorted"
        );
    }
}
/// Stores seven successful and three failed executions for a fresh
/// document and expects the reported success rate to be within 0.01 of 0.7.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_get_rlm_success_rate() {
    let persistence = setup_test_db().await;
    let doc_id = format!("doc-{}", uuid::Uuid::new_v4());

    // Indices 0..7 succeed, 7..10 fail.
    for idx in 0..10 {
        let base = PersistedRlmExecution::builder(
            format!("exec-{}-{}", doc_id, idx),
            doc_id.clone(),
            format!("Query {}", idx),
        )
        .result("Result".to_string())
        .duration_ms(1000)
        .tokens(800, 400)
        .provider("claude".to_string());

        let configured = if idx < 7 {
            base.success(true)
        } else {
            base.error("Error".to_string())
        };

        persistence
            .persist_rlm_execution(configured.build())
            .await
            .unwrap();
    }

    let success_rate = persistence.get_rlm_success_rate(&doc_id).await.unwrap();
    let deviation = (success_rate - 0.7).abs();
    assert!(
        deviation < 0.01,
        "Success rate should be ~0.7, got {}",
        success_rate
    );
}
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_find_similar_rlm_tasks() {
let persistence = setup_test_db().await;
// Create multiple successful executions
for i in 0..5 {
let execution = PersistedRlmExecution::builder(
format!("exec-similar-{}", i),
format!("doc-{}", i),
format!("Query about topic {}", i % 3),
)
.result("Result".to_string())
.duration_ms(1000)
.tokens(800, 400)
.provider("claude".to_string())
.success(true)
.query_embedding(vec![0.1 * i as f32; 1536])
.build();
persistence.persist_rlm_execution(execution).await.unwrap();
}
// Search for similar tasks
let query_embedding = vec![0.1; 1536];
let similar = persistence
.find_similar_rlm_tasks(&query_embedding, 3)
.await
.unwrap();
assert!(!similar.is_empty(), "Should find similar tasks");
assert!(similar.len() <= 3, "Should respect limit");
// All returned tasks should be successful
for task in similar {
assert!(task.success, "Similar tasks should be successful");
}
}
/// Stores five executions with identical cost and token counts for a fresh
/// document and checks the aggregated cost summary for
/// `TimePeriod::LastDay`.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_get_rlm_cost_summary() {
    let persistence = setup_test_db().await;
    let doc_id = format!("doc-{}", uuid::Uuid::new_v4());

    // Each execution: 1000 input tokens, 500 output tokens, 100 cents.
    for n in 0..5 {
        let exec_id = format!("exec-{}-{}", doc_id, n);
        let query = format!("Query {}", n);

        let record = PersistedRlmExecution::builder(exec_id, doc_id.clone(), query)
            .result("Result".to_string())
            .duration_ms(1000)
            .tokens(1000, 500)
            .provider("claude".to_string())
            .success(true)
            .cost_cents(100.0)
            .build();

        persistence.persist_rlm_execution(record).await.unwrap();
    }

    let summary = persistence
        .get_rlm_cost_summary(&doc_id, TimePeriod::LastDay)
        .await
        .unwrap();
    let (total_cost, input_tokens, output_tokens) = summary;

    // 5 executions x (100 cents, 1000 in, 500 out).
    assert_eq!(total_cost, 500.0, "Total cost should be 500 cents");
    assert_eq!(input_tokens, 5000, "Total input tokens should be 5000");
    assert_eq!(output_tokens, 2500, "Total output tokens should be 2500");
}
/// Persists five executions in one batch call for a fresh document and
/// checks that all five can be read back by document id.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_batch_persist_rlm_executions() {
    let persistence = setup_test_db().await;
    let doc_id = format!("doc-{}", uuid::Uuid::new_v4());

    let executions: Vec<PersistedRlmExecution> = (0..5)
        .map(|i| {
            PersistedRlmExecution::builder(
                // Derive the execution id from the per-run doc id (as the
                // other tests do) so repeated runs never reuse an id.
                format!("batch-exec-{}-{}", doc_id, i),
                doc_id.clone(),
                format!("Batch query {}", i),
            )
            .result("Result".to_string())
            .duration_ms(1000)
            .tokens(800, 400)
            .provider("claude".to_string())
            .success(true)
            .build()
        })
        .collect();

    let result = persistence.persist_rlm_executions(executions).await;
    assert!(
        result.is_ok(),
        "Batch persist should succeed: {:?}",
        result.err()
    );

    let retrieved = persistence
        .get_rlm_executions_by_doc(&doc_id, 10)
        .await
        .unwrap();
    assert_eq!(retrieved.len(), 5, "Should retrieve all 5 executions");
}
/// Persists one execution and then runs `cleanup_old_rlm_executions(0)`,
/// asserting only that the cleanup call returns `Ok`.
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_cleanup_old_rlm_executions() {
    let persistence = setup_test_db().await;

    let doc_id = format!("doc-cleanup-{}", uuid::Uuid::new_v4());
    let exec_id = format!("old-exec-{}", uuid::Uuid::new_v4());

    // Seed one record so the cleanup has something to act on.
    let seeded = PersistedRlmExecution::builder(exec_id, doc_id.clone(), "Old query".to_string())
        .result("Result".to_string())
        .duration_ms(1000)
        .tokens(800, 400)
        .provider("claude".to_string())
        .success(true)
        .build();
    persistence.persist_rlm_execution(seeded).await.unwrap();

    // A retention of 0 days is expected to remove every stored execution.
    // NOTE(review): if that is what it does, it may also remove rows that
    // other tests running concurrently against the same database rely on -
    // confirm.
    let outcome = persistence.cleanup_old_rlm_executions(0).await;

    // SurrealDB doesn't return a delete count, so only completion without
    // error is verified here.
    assert!(outcome.is_ok(), "Cleanup should succeed");
}
#[tokio::test]
#[ignore] // Requires SurrealDB
async fn test_get_rlm_execution_count() {
let persistence = setup_test_db().await;
let initial_count = persistence.get_rlm_execution_count().await.unwrap();
// Add 3 executions
for i in 0..3 {
let execution = PersistedRlmExecution::builder(
format!("count-exec-{}", i),
"count-doc".to_string(),
"Query".to_string(),
)
.result("Result".to_string())
.duration_ms(1000)
.tokens(800, 400)
.provider("claude".to_string())
.success(true)
.build();
persistence.persist_rlm_execution(execution).await.unwrap();
}
let final_count = persistence.get_rlm_execution_count().await.unwrap();
assert!(
final_count >= initial_count + 3,
"Count should increase by at least 3"
);
}