// SurrealDB Storage Adapter for RLM
// Follows KGPersistence pattern from vapora-knowledge-graph
use std::sync::Arc;
|
|
|
|
use async_trait::async_trait;
|
|
use chrono::Utc;
|
|
use surrealdb::engine::remote::ws::Client;
|
|
use surrealdb::Surreal;
|
|
use tracing::{debug, error};
|
|
|
|
use super::{Buffer, Chunk, ExecutionHistory, Storage};
|
|
use crate::metrics::STORAGE_OPERATIONS;
|
|
use crate::RLMError;
|
|
|
|
/// SurrealDB storage implementation for RLM
///
/// Wraps a WebSocket-engine [`Surreal<Client>`] connection in an [`Arc`]
/// so one handle can be shared across components (see [`SurrealDBStorage::from_arc`]).
pub struct SurrealDBStorage {
    // Shared SurrealDB connection handle; cloned cheaply via Arc.
    db: Arc<Surreal<Client>>,
}
|
|
|
|
// Manual Debug impl: the inner connection handle is replaced with a fixed
// placeholder — the raw `Surreal<Client>` value is not meaningful (and not
// guaranteed printable) in debug output.
impl std::fmt::Debug for SurrealDBStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SurrealDBStorage")
            // Redacted on purpose; only the type name matters here.
            .field("db", &"<SurrealDB>")
            .finish()
    }
}
|
|
|
|
impl SurrealDBStorage {
|
|
/// Create new SurrealDB storage
|
|
pub fn new(db: Surreal<Client>) -> Self {
|
|
Self { db: Arc::new(db) }
|
|
}
|
|
|
|
/// Create from Arc (for sharing across components)
|
|
pub fn from_arc(db: Arc<Surreal<Client>>) -> Self {
|
|
Self { db }
|
|
}
|
|
}
|
|
|
|
// Storage trait implementation. Every operation follows the same shape:
//   1. log intent at `debug!` level,
//   2. run one parameterized SurrealQL query (all values passed via
//      `.bind`, never string-interpolated — no injection surface),
//   3. bump the STORAGE_OPERATIONS counter with an
//      (operation, "success" | "error") label pair,
//   4. wrap any SurrealDB failure in `RLMError::DatabaseError`.
#[async_trait]
impl Storage for SurrealDBStorage {
    // Persist one chunk as a new row in `rlm_chunks`.
    async fn save_chunk(&self, chunk: Chunk) -> crate::Result<()> {
        debug!(
            "Saving chunk {} for document {}",
            chunk.chunk_id, chunk.doc_id
        );

        // `\` continuations form a single SQL string; leading whitespace
        // of each continuation line is stripped by the compiler.
        let query = "CREATE rlm_chunks SET chunk_id = $chunk_id, doc_id = $doc_id, content = \
                     $content, embedding = $embedding, start_idx = $start_idx, end_idx = \
                     $end_idx, metadata = $metadata, created_at = $created_at";

        // Fields are cloned because `.bind` needs owned values and
        // `chunk.chunk_id` is still used in the error log below.
        let result = self
            .db
            .query(query)
            .bind(("chunk_id", chunk.chunk_id.clone()))
            .bind(("doc_id", chunk.doc_id.clone()))
            .bind(("content", chunk.content.clone()))
            .bind(("embedding", chunk.embedding.clone()))
            // Cast to i64: SurrealDB integers are signed 64-bit.
            .bind(("start_idx", chunk.start_idx as i64))
            .bind(("end_idx", chunk.end_idx as i64))
            .bind(("metadata", chunk.metadata.clone()))
            .bind(("created_at", chunk.created_at.clone()))
            .await;

        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_chunk", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save chunk {}: {}", chunk.chunk_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_chunk", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    // Fetch all chunks of a document, ordered by position (`start_idx`
    // ascending) so callers can reassemble the original text.
    async fn get_chunks(&self, doc_id: &str) -> crate::Result<Vec<Chunk>> {
        debug!("Fetching chunks for document {}", doc_id);

        let query = "SELECT * FROM rlm_chunks WHERE doc_id = $doc_id ORDER BY start_idx ASC";
        let mut response = self
            .db
            .query(query)
            // `.bind` requires an owned value, hence `to_string`.
            .bind(("doc_id", doc_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch chunks for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_chunks", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        // `take(0)` deserializes the result set of the first (and only)
        // statement in the query.
        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse chunks for doc {}: {}", doc_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;

        STORAGE_OPERATIONS
            .with_label_values(&["get_chunks", "success"])
            .inc();
        Ok(results)
    }

    // Fetch a single chunk by id; `None` when no row matches.
    async fn get_chunk(&self, chunk_id: &str) -> crate::Result<Option<Chunk>> {
        debug!("Fetching chunk {}", chunk_id);

        let query = "SELECT * FROM rlm_chunks WHERE chunk_id = $chunk_id LIMIT 1";
        let mut response = self
            .db
            .query(query)
            .bind(("chunk_id", chunk_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch chunk {}: {}", chunk_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_chunk", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse chunk {}: {}", chunk_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;

        STORAGE_OPERATIONS
            .with_label_values(&["get_chunk", "success"])
            .inc();
        // LIMIT 1 means the Vec has at most one element.
        Ok(results.into_iter().next())
    }

    // Approximate similarity search: pulls the `limit` most recent chunks
    // that have an embedding, then ranks them in memory.
    //
    // NOTE(review): the candidate pool is already capped at `limit`
    // before ranking, so ranking only reorders the newest chunks — an
    // older but more similar chunk can never be surfaced. Revisit when
    // real vector search is wired in.
    async fn search_by_embedding(
        &self,
        embedding: &[f32],
        limit: usize,
    ) -> crate::Result<Vec<Chunk>> {
        debug!("Searching for similar chunks (limit: {})", limit);

        // SurrealDB vector similarity search
        // For now, return recent chunks with embeddings
        // TODO: Implement proper vector similarity when SurrealDB supports it
        let query = "SELECT * FROM rlm_chunks WHERE embedding != NONE ORDER BY created_at DESC \
                     LIMIT $limit";
        let mut response = self
            .db
            .query(query)
            .bind(("limit", limit as i64))
            .await
            .map_err(|e| {
                error!("Failed to search by embedding: {}", e);
                STORAGE_OPERATIONS
                    .with_label_values(&["search_embedding", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse embedding search results: {}", e);
            RLMError::DatabaseError(Box::new(e))
        })?;

        STORAGE_OPERATIONS
            .with_label_values(&["search_embedding", "success"])
            .inc();

        // Filter and rank by cosine similarity (in-memory for now)
        let ranked = self.rank_by_similarity(&results, embedding, limit);
        Ok(ranked)
    }

    // Persist a scratch buffer row in `rlm_buffers`; `expires_at` is
    // honored later by `cleanup_expired_buffers`.
    async fn save_buffer(&self, buffer: Buffer) -> crate::Result<()> {
        debug!("Saving buffer {}", buffer.buffer_id);

        let query = "CREATE rlm_buffers SET buffer_id = $buffer_id, content = $content, metadata \
                     = $metadata, expires_at = $expires_at, created_at = $created_at";

        // Clones keep `buffer.buffer_id` available for the error log.
        let result = self
            .db
            .query(query)
            .bind(("buffer_id", buffer.buffer_id.clone()))
            .bind(("content", buffer.content.clone()))
            .bind(("metadata", buffer.metadata.clone()))
            .bind(("expires_at", buffer.expires_at.clone()))
            .bind(("created_at", buffer.created_at.clone()))
            .await;

        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_buffer", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save buffer {}: {}", buffer.buffer_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_buffer", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    // Fetch a single buffer by id; `None` when no row matches.
    async fn get_buffer(&self, buffer_id: &str) -> crate::Result<Option<Buffer>> {
        debug!("Fetching buffer {}", buffer_id);

        let query = "SELECT * FROM rlm_buffers WHERE buffer_id = $buffer_id LIMIT 1";
        let mut response = self
            .db
            .query(query)
            .bind(("buffer_id", buffer_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch buffer {}: {}", buffer_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_buffer", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        let results: Vec<Buffer> = response.take(0).map_err(|e| {
            error!("Failed to parse buffer {}: {}", buffer_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;

        STORAGE_OPERATIONS
            .with_label_values(&["get_buffer", "success"])
            .inc();
        Ok(results.into_iter().next())
    }

    // Delete buffers whose expiry timestamp is in the past.
    //
    // NOTE(review): always returns Ok(0) regardless of rows deleted (see
    // comment below); callers must not rely on the count. Assumes
    // `expires_at` is stored as RFC 3339 text so the `<` comparison is
    // chronological — TODO confirm against the write path.
    async fn cleanup_expired_buffers(&self) -> crate::Result<u64> {
        debug!("Cleaning up expired buffers");

        let now = Utc::now().to_rfc3339();
        let query = "DELETE FROM rlm_buffers WHERE expires_at != NONE AND expires_at < $now";

        let mut response = self.db.query(query).bind(("now", now)).await.map_err(|e| {
            error!("Failed to cleanup expired buffers: {}", e);
            STORAGE_OPERATIONS
                .with_label_values(&["cleanup_buffers", "error"])
                .inc();
            RLMError::DatabaseError(Box::new(e))
        })?;

        // SurrealDB 2.x doesn't return delete count easily
        // Drain the statement result so the response is fully consumed.
        let _: Vec<serde_json::Value> = response.take(0).unwrap_or_default();

        STORAGE_OPERATIONS
            .with_label_values(&["cleanup_buffers", "success"])
            .inc();
        Ok(0)
    }

    // Persist one execution-history record in `rlm_executions`.
    async fn save_execution(&self, execution: ExecutionHistory) -> crate::Result<()> {
        debug!(
            "Saving execution {} for document {}",
            execution.execution_id, execution.doc_id
        );

        let query = "CREATE rlm_executions SET execution_id = $execution_id, doc_id = $doc_id, \
                     query = $query, chunks_used = $chunks_used, result = $result, duration_ms = \
                     $duration_ms, cost_cents = $cost_cents, provider = $provider, success = \
                     $success, error_message = $error_message, metadata = $metadata, created_at = \
                     $created_at, executed_at = $executed_at";

        // Clones keep `execution.execution_id` available for the error log.
        let result = self
            .db
            .query(query)
            .bind(("execution_id", execution.execution_id.clone()))
            .bind(("doc_id", execution.doc_id.clone()))
            .bind(("query", execution.query.clone()))
            .bind(("chunks_used", execution.chunks_used.clone()))
            .bind(("result", execution.result.clone()))
            // Cast to i64: SurrealDB integers are signed 64-bit.
            .bind(("duration_ms", execution.duration_ms as i64))
            .bind(("cost_cents", execution.cost_cents))
            .bind(("provider", execution.provider.clone()))
            .bind(("success", execution.success))
            .bind(("error_message", execution.error_message.clone()))
            .bind(("metadata", execution.metadata.clone()))
            .bind(("created_at", execution.created_at.clone()))
            .bind(("executed_at", execution.executed_at.clone()))
            .await;

        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_execution", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save execution {}: {}", execution.execution_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_execution", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    // Fetch the `limit` most recent executions for a document,
    // newest first (`executed_at DESC`).
    async fn get_executions(
        &self,
        doc_id: &str,
        limit: usize,
    ) -> crate::Result<Vec<ExecutionHistory>> {
        debug!(
            "Fetching executions for document {} (limit: {})",
            doc_id, limit
        );

        let query = "SELECT * FROM rlm_executions WHERE doc_id = $doc_id ORDER BY executed_at \
                     DESC LIMIT $limit";
        let mut response = self
            .db
            .query(query)
            .bind(("doc_id", doc_id.to_string()))
            .bind(("limit", limit as i64))
            .await
            .map_err(|e| {
                error!("Failed to fetch executions for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_executions", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        let results: Vec<ExecutionHistory> = response.take(0).map_err(|e| {
            error!("Failed to parse executions for doc {}: {}", doc_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;

        STORAGE_OPERATIONS
            .with_label_values(&["get_executions", "success"])
            .inc();
        Ok(results)
    }

    // Delete every chunk belonging to a document.
    //
    // NOTE(review): always returns Ok(0) regardless of rows deleted (see
    // comment below); callers must not rely on the count.
    async fn delete_chunks(&self, doc_id: &str) -> crate::Result<u64> {
        debug!("Deleting chunks for document {}", doc_id);

        let query = "DELETE FROM rlm_chunks WHERE doc_id = $doc_id";
        let mut response = self
            .db
            .query(query)
            .bind(("doc_id", doc_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to delete chunks for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["delete_chunks", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;

        // SurrealDB 2.x doesn't return delete count easily
        // Drain the statement result so the response is fully consumed.
        let _: Vec<serde_json::Value> = response.take(0).unwrap_or_default();

        STORAGE_OPERATIONS
            .with_label_values(&["delete_chunks", "success"])
            .inc();
        Ok(0)
    }
}
|
|
|
|
impl SurrealDBStorage {
|
|
/// Rank chunks by cosine similarity to query embedding (in-memory)
|
|
fn rank_by_similarity(
|
|
&self,
|
|
chunks: &[Chunk],
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Vec<Chunk> {
|
|
let mut scored: Vec<(f32, Chunk)> = chunks
|
|
.iter()
|
|
.filter_map(|chunk| {
|
|
if let Some(ref embedding) = chunk.embedding {
|
|
let similarity = cosine_similarity(embedding, query_embedding);
|
|
Some((similarity, chunk.clone()))
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
// Sort by similarity descending
|
|
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
// Take top N
|
|
scored
|
|
.into_iter()
|
|
.take(limit)
|
|
.map(|(_, chunk)| chunk)
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
/// Cosine similarity between two vectors
///
/// Returns `dot(a, b) / (|a| * |b|)`, or `0.0` when the lengths differ
/// or either vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensionality cannot be compared meaningfully.
    if a.len() != b.len() {
        return 0.0;
    }

    // Single fused pass accumulating dot product and both squared norms,
    // in the same element order as separate sums would use.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let magnitude_a = sq_a.sqrt();
    let magnitude_b = sq_b.sqrt();

    // Guard against division by zero for zero-magnitude inputs.
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }

    dot / (magnitude_a * magnitude_b)
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        // Identical unit vectors → similarity of exactly 1.
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_x_copy = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &unit_x_copy) - 1.0).abs() < 0.001);

        // Orthogonal vectors → similarity of 0.
        let axis_x = vec![1.0, 0.0, 0.0];
        let axis_y = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&axis_x, &axis_y) - 0.0).abs() < 0.001);

        // 45° apart → similarity ≈ 1/√2 ≈ 0.707.
        let diagonal = vec![1.0, 1.0, 0.0];
        let axis = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&diagonal, &axis);
        assert!(similarity > 0.7 && similarity < 0.8);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        // Empty inputs, mismatched lengths, and a zero-magnitude vector
        // all fall back to 0.0 rather than dividing by zero.
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
}
|