// RLM Integration Tests
//
// These tests are `#[ignore]`d by default. They need a running SurrealDB that
// accepts the `root`/`root` credentials used by `setup_test_db`, e.g.:
//
//   docker run -p 8000:8000 surrealdb/surrealdb:latest \
//     start --user root --pass root --bind 0.0.0.0:8000
//
// Then run them with `cargo test -- --ignored`.
use surrealdb::engine::remote::ws::Ws;
|
||
|
|
use surrealdb::opt::auth::Root;
|
||
|
|
use surrealdb::Surreal;
|
||
|
|
use vapora_knowledge_graph::persistence::{KGPersistence, PersistedRlmExecution};
|
||
|
|
use vapora_knowledge_graph::TimePeriod;
|
||
|
|
|
||
|
|
async fn setup_test_db() -> KGPersistence {
|
||
|
|
let db = Surreal::new::<Ws>("127.0.0.1:8000").await.unwrap();
|
||
|
|
|
||
|
|
db.signin(Root {
|
||
|
|
username: "root",
|
||
|
|
password: "root",
|
||
|
|
})
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
db.use_ns("test_rlm").use_db("test_rlm").await.unwrap();
|
||
|
|
|
||
|
|
KGPersistence::new(db)
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_persist_rlm_execution() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
|
||
|
|
let execution = PersistedRlmExecution::builder(
|
||
|
|
format!("test-exec-{}", uuid::Uuid::new_v4()),
|
||
|
|
"doc-1".to_string(),
|
||
|
|
"What is Rust?".to_string(),
|
||
|
|
)
|
||
|
|
.chunks_used(vec!["chunk-1".to_string(), "chunk-2".to_string()])
|
||
|
|
.result("Rust is a systems programming language".to_string())
|
||
|
|
.duration_ms(3000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.num_llm_calls(2)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.cost_cents(120.0)
|
||
|
|
.aggregation_strategy("Concatenate".to_string())
|
||
|
|
.build();
|
||
|
|
|
||
|
|
let exec_id = execution.execution_id.clone();
|
||
|
|
let doc_id = execution.doc_id.clone();
|
||
|
|
|
||
|
|
let result = persistence.persist_rlm_execution(execution).await;
|
||
|
|
assert!(result.is_ok(), "Failed to persist RLM execution");
|
||
|
|
|
||
|
|
// Verify can retrieve
|
||
|
|
let executions = persistence
|
||
|
|
.get_rlm_executions_by_doc(&doc_id, 10)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
assert!(!executions.is_empty());
|
||
|
|
assert_eq!(executions[0].execution_id, exec_id);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_get_rlm_learning_curve() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
let doc_id = format!("doc-{}", uuid::Uuid::new_v4());
|
||
|
|
|
||
|
|
// Create multiple executions over time
|
||
|
|
for i in 0..10 {
|
||
|
|
let execution = PersistedRlmExecution::builder(
|
||
|
|
format!("exec-{}-{}", doc_id, i),
|
||
|
|
doc_id.clone(),
|
||
|
|
format!("Query {}", i),
|
||
|
|
)
|
||
|
|
.chunks_used(vec![format!("chunk-{}", i)])
|
||
|
|
.result(format!("Result {}", i))
|
||
|
|
.duration_ms(1000 + (i as u64 * 100))
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(i % 3 != 0) // Success rate ~66%
|
||
|
|
.build();
|
||
|
|
|
||
|
|
let mut exec = execution;
|
||
|
|
if i % 3 == 0 {
|
||
|
|
// Add error for failed executions
|
||
|
|
exec = PersistedRlmExecution::builder(
|
||
|
|
format!("exec-{}-{}", doc_id, i),
|
||
|
|
doc_id.clone(),
|
||
|
|
format!("Query {}", i),
|
||
|
|
)
|
||
|
|
.chunks_used(vec![format!("chunk-{}", i)])
|
||
|
|
.result(format!("Result {}", i))
|
||
|
|
.duration_ms(1000 + (i as u64 * 100))
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.error("Test error".to_string())
|
||
|
|
.build();
|
||
|
|
}
|
||
|
|
|
||
|
|
persistence.persist_rlm_execution(exec).await.unwrap();
|
||
|
|
|
||
|
|
// Small delay to ensure different timestamps
|
||
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get learning curve
|
||
|
|
let curve = persistence
|
||
|
|
.get_rlm_learning_curve(&doc_id, 1)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
assert!(!curve.is_empty(), "Learning curve should not be empty");
|
||
|
|
|
||
|
|
// Verify chronological ordering
|
||
|
|
for i in 1..curve.len() {
|
||
|
|
assert!(
|
||
|
|
curve[i - 1].0 <= curve[i].0,
|
||
|
|
"Curve must be chronologically sorted"
|
||
|
|
);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_get_rlm_success_rate() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
let doc_id = format!("doc-{}", uuid::Uuid::new_v4());
|
||
|
|
|
||
|
|
// Create 10 executions: 7 successes, 3 failures
|
||
|
|
for i in 0..10 {
|
||
|
|
let mut builder = PersistedRlmExecution::builder(
|
||
|
|
format!("exec-{}-{}", doc_id, i),
|
||
|
|
doc_id.clone(),
|
||
|
|
format!("Query {}", i),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string());
|
||
|
|
|
||
|
|
if i < 7 {
|
||
|
|
builder = builder.success(true);
|
||
|
|
} else {
|
||
|
|
builder = builder.error("Error".to_string());
|
||
|
|
}
|
||
|
|
|
||
|
|
let execution = builder.build();
|
||
|
|
persistence.persist_rlm_execution(execution).await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
let success_rate = persistence.get_rlm_success_rate(&doc_id).await.unwrap();
|
||
|
|
assert!(
|
||
|
|
(success_rate - 0.7).abs() < 0.01,
|
||
|
|
"Success rate should be ~0.7, got {}",
|
||
|
|
success_rate
|
||
|
|
);
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_find_similar_rlm_tasks() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
|
||
|
|
// Create multiple successful executions
|
||
|
|
for i in 0..5 {
|
||
|
|
let execution = PersistedRlmExecution::builder(
|
||
|
|
format!("exec-similar-{}", i),
|
||
|
|
format!("doc-{}", i),
|
||
|
|
format!("Query about topic {}", i % 3),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.query_embedding(vec![0.1 * i as f32; 1536])
|
||
|
|
.build();
|
||
|
|
|
||
|
|
persistence.persist_rlm_execution(execution).await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Search for similar tasks
|
||
|
|
let query_embedding = vec![0.1; 1536];
|
||
|
|
let similar = persistence
|
||
|
|
.find_similar_rlm_tasks(&query_embedding, 3)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert!(!similar.is_empty(), "Should find similar tasks");
|
||
|
|
assert!(similar.len() <= 3, "Should respect limit");
|
||
|
|
|
||
|
|
// All returned tasks should be successful
|
||
|
|
for task in similar {
|
||
|
|
assert!(task.success, "Similar tasks should be successful");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_get_rlm_cost_summary() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
let doc_id = format!("doc-{}", uuid::Uuid::new_v4());
|
||
|
|
|
||
|
|
// Create executions with known costs and tokens
|
||
|
|
for i in 0..5 {
|
||
|
|
let execution = PersistedRlmExecution::builder(
|
||
|
|
format!("exec-{}-{}", doc_id, i),
|
||
|
|
doc_id.clone(),
|
||
|
|
format!("Query {}", i),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(1000, 500) // 1k input, 500 output
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.cost_cents(100.0) // 100 cents each
|
||
|
|
.build();
|
||
|
|
|
||
|
|
persistence.persist_rlm_execution(execution).await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
let (total_cost, input_tokens, output_tokens) = persistence
|
||
|
|
.get_rlm_cost_summary(&doc_id, TimePeriod::LastDay)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
assert_eq!(total_cost, 500.0, "Total cost should be 500 cents");
|
||
|
|
assert_eq!(input_tokens, 5000, "Total input tokens should be 5000");
|
||
|
|
assert_eq!(output_tokens, 2500, "Total output tokens should be 2500");
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_batch_persist_rlm_executions() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
let doc_id = format!("doc-{}", uuid::Uuid::new_v4());
|
||
|
|
|
||
|
|
let executions: Vec<PersistedRlmExecution> = (0..5)
|
||
|
|
.map(|i| {
|
||
|
|
PersistedRlmExecution::builder(
|
||
|
|
format!("batch-exec-{}", i),
|
||
|
|
doc_id.clone(),
|
||
|
|
format!("Batch query {}", i),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.build()
|
||
|
|
})
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
let result = persistence.persist_rlm_executions(executions).await;
|
||
|
|
assert!(result.is_ok(), "Batch persist should succeed");
|
||
|
|
|
||
|
|
let retrieved = persistence
|
||
|
|
.get_rlm_executions_by_doc(&doc_id, 10)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
assert_eq!(retrieved.len(), 5, "Should retrieve all 5 executions");
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_cleanup_old_rlm_executions() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
let doc_id = format!("doc-cleanup-{}", uuid::Uuid::new_v4());
|
||
|
|
|
||
|
|
// Create an old execution
|
||
|
|
let old_execution = PersistedRlmExecution::builder(
|
||
|
|
format!("old-exec-{}", uuid::Uuid::new_v4()),
|
||
|
|
doc_id.clone(),
|
||
|
|
"Old query".to_string(),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
persistence
|
||
|
|
.persist_rlm_execution(old_execution)
|
||
|
|
.await
|
||
|
|
.unwrap();
|
||
|
|
|
||
|
|
// Cleanup executions older than 0 days (should delete all)
|
||
|
|
let result = persistence.cleanup_old_rlm_executions(0).await;
|
||
|
|
assert!(result.is_ok(), "Cleanup should succeed");
|
||
|
|
|
||
|
|
// Note: SurrealDB doesn't return delete count, so we can't verify count
|
||
|
|
// But we can verify the operation completed without error
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB
|
||
|
|
async fn test_get_rlm_execution_count() {
|
||
|
|
let persistence = setup_test_db().await;
|
||
|
|
|
||
|
|
let initial_count = persistence.get_rlm_execution_count().await.unwrap();
|
||
|
|
|
||
|
|
// Add 3 executions
|
||
|
|
for i in 0..3 {
|
||
|
|
let execution = PersistedRlmExecution::builder(
|
||
|
|
format!("count-exec-{}", i),
|
||
|
|
"count-doc".to_string(),
|
||
|
|
"Query".to_string(),
|
||
|
|
)
|
||
|
|
.result("Result".to_string())
|
||
|
|
.duration_ms(1000)
|
||
|
|
.tokens(800, 400)
|
||
|
|
.provider("claude".to_string())
|
||
|
|
.success(true)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
persistence.persist_rlm_execution(execution).await.unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
let final_count = persistence.get_rlm_execution_count().await.unwrap();
|
||
|
|
assert!(
|
||
|
|
final_count >= initial_count + 3,
|
||
|
|
"Count should increase by at least 3"
|
||
|
|
);
|
||
|
|
}
|