//! SurrealDB storage adapter for RLM.
//!
//! Follows the `KGPersistence` pattern from `vapora-knowledge-graph`.
use std::sync::Arc;
use async_trait::async_trait;
use chrono::Utc;
use surrealdb::engine::remote::ws::Client;
use surrealdb::Surreal;
use tracing::{debug, error};
use super::{Buffer, Chunk, ExecutionHistory, Storage};
use crate::metrics::STORAGE_OPERATIONS;
use crate::RLMError;
/// SurrealDB storage implementation for RLM
///
/// Thin wrapper around a WebSocket-backed SurrealDB client. All chunk,
/// buffer, and execution-history persistence goes through this handle.
pub struct SurrealDBStorage {
    /// Shared connection handle so the same client can back multiple
    /// components (see `from_arc`).
    db: Arc<Surreal<Client>>,
}
impl std::fmt::Debug for SurrealDBStorage {
    /// Manual `Debug`: the SurrealDB client has no useful `Debug` output,
    /// so the connection is rendered as an opaque `"<SurrealDB>"` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SurrealDBStorage");
        builder.field("db", &"<SurrealDB>");
        builder.finish()
    }
}
impl SurrealDBStorage {
    /// Create storage from an owned SurrealDB connection.
    ///
    /// Wraps the client in an `Arc` and delegates to [`Self::from_arc`].
    pub fn new(db: Surreal<Client>) -> Self {
        Self::from_arc(Arc::new(db))
    }

    /// Create storage from an already-shared connection handle, allowing the
    /// same client to be reused across components.
    pub fn from_arc(db: Arc<Surreal<Client>>) -> Self {
        Self { db }
    }
}
#[async_trait]
impl Storage for SurrealDBStorage {
    /// Persist a single chunk into the `rlm_chunks` table.
    ///
    /// Increments the `save_chunk` success/error counter either way; database
    /// failures are wrapped in `RLMError::DatabaseError`.
    async fn save_chunk(&self, chunk: Chunk) -> crate::Result<()> {
        debug!(
            "Saving chunk {} for document {}",
            chunk.chunk_id, chunk.doc_id
        );
        // Trailing `\` inside the literal joins the lines into one query
        // string; the `$name` placeholders are filled by the binds below.
        let query = "CREATE rlm_chunks SET chunk_id = $chunk_id, doc_id = $doc_id, content = \
                     $content, embedding = $embedding, start_idx = $start_idx, end_idx = \
                     $end_idx, metadata = $metadata, created_at = $created_at";
        let result = self
            .db
            .query(query)
            // Fields are cloned because `chunk` is still needed for the
            // error log in the failure arm below.
            .bind(("chunk_id", chunk.chunk_id.clone()))
            .bind(("doc_id", chunk.doc_id.clone()))
            .bind(("content", chunk.content.clone()))
            .bind(("embedding", chunk.embedding.clone()))
            // usize -> i64 cast; assumes indices fit in i64 (realistic for
            // document offsets on 64-bit targets).
            .bind(("start_idx", chunk.start_idx as i64))
            .bind(("end_idx", chunk.end_idx as i64))
            .bind(("metadata", chunk.metadata.clone()))
            .bind(("created_at", chunk.created_at.clone()))
            .await;
        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_chunk", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save chunk {}: {}", chunk.chunk_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_chunk", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    /// Fetch every chunk belonging to `doc_id`, ordered by `start_idx`
    /// ascending (i.e. document order).
    async fn get_chunks(&self, doc_id: &str) -> crate::Result<Vec<Chunk>> {
        debug!("Fetching chunks for document {}", doc_id);
        let query = "SELECT * FROM rlm_chunks WHERE doc_id = $doc_id ORDER BY start_idx ASC";
        let mut response = self
            .db
            .query(query)
            .bind(("doc_id", doc_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch chunks for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_chunks", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        // take(0): deserialize the result set of the first (only) statement.
        // NOTE(review): deserialization failures are not counted in the
        // "error" metric label — only transport/query failures are; confirm
        // whether that asymmetry is intentional.
        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse chunks for doc {}: {}", doc_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;
        STORAGE_OPERATIONS
            .with_label_values(&["get_chunks", "success"])
            .inc();
        Ok(results)
    }

    /// Fetch a single chunk by its `chunk_id`; returns `None` when no row
    /// matches.
    async fn get_chunk(&self, chunk_id: &str) -> crate::Result<Option<Chunk>> {
        debug!("Fetching chunk {}", chunk_id);
        let query = "SELECT * FROM rlm_chunks WHERE chunk_id = $chunk_id LIMIT 1";
        let mut response = self
            .db
            .query(query)
            .bind(("chunk_id", chunk_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch chunk {}: {}", chunk_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_chunk", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse chunk {}: {}", chunk_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;
        STORAGE_OPERATIONS
            .with_label_values(&["get_chunk", "success"])
            .inc();
        // LIMIT 1 above means at most one element; take it if present.
        Ok(results.into_iter().next())
    }

    /// Approximate nearest-neighbour search over stored chunk embeddings.
    ///
    /// Currently fetches the `limit` most recent chunks that have an
    /// embedding, then re-ranks them in memory by cosine similarity — so
    /// recall is bounded by recency, not true vector distance (see TODO).
    async fn search_by_embedding(
        &self,
        embedding: &[f32],
        limit: usize,
    ) -> crate::Result<Vec<Chunk>> {
        debug!("Searching for similar chunks (limit: {})", limit);
        // SurrealDB vector similarity search
        // For now, return recent chunks with embeddings
        // TODO: Implement proper vector similarity when SurrealDB supports it
        let query = "SELECT * FROM rlm_chunks WHERE embedding != NONE ORDER BY created_at DESC \
                     LIMIT $limit";
        let mut response = self
            .db
            .query(query)
            .bind(("limit", limit as i64))
            .await
            .map_err(|e| {
                error!("Failed to search by embedding: {}", e);
                STORAGE_OPERATIONS
                    .with_label_values(&["search_embedding", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        let results: Vec<Chunk> = response.take(0).map_err(|e| {
            error!("Failed to parse embedding search results: {}", e);
            RLMError::DatabaseError(Box::new(e))
        })?;
        STORAGE_OPERATIONS
            .with_label_values(&["search_embedding", "success"])
            .inc();
        // Filter and rank by cosine similarity (in-memory for now)
        let ranked = self.rank_by_similarity(&results, embedding, limit);
        Ok(ranked)
    }

    /// Persist a scratch buffer into the `rlm_buffers` table.
    ///
    /// `expires_at` may be unset; expired buffers are reaped by
    /// `cleanup_expired_buffers`.
    async fn save_buffer(&self, buffer: Buffer) -> crate::Result<()> {
        debug!("Saving buffer {}", buffer.buffer_id);
        let query = "CREATE rlm_buffers SET buffer_id = $buffer_id, content = $content, metadata \
                     = $metadata, expires_at = $expires_at, created_at = $created_at";
        let result = self
            .db
            .query(query)
            // Clones keep `buffer` alive for the error log below.
            .bind(("buffer_id", buffer.buffer_id.clone()))
            .bind(("content", buffer.content.clone()))
            .bind(("metadata", buffer.metadata.clone()))
            .bind(("expires_at", buffer.expires_at.clone()))
            .bind(("created_at", buffer.created_at.clone()))
            .await;
        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_buffer", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save buffer {}: {}", buffer.buffer_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_buffer", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    /// Fetch a buffer by id; returns `None` when no row matches.
    ///
    /// NOTE(review): expiry is not checked here — an expired-but-not-yet
    /// -cleaned buffer is still returned; confirm callers expect that.
    async fn get_buffer(&self, buffer_id: &str) -> crate::Result<Option<Buffer>> {
        debug!("Fetching buffer {}", buffer_id);
        let query = "SELECT * FROM rlm_buffers WHERE buffer_id = $buffer_id LIMIT 1";
        let mut response = self
            .db
            .query(query)
            .bind(("buffer_id", buffer_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to fetch buffer {}: {}", buffer_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_buffer", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        let results: Vec<Buffer> = response.take(0).map_err(|e| {
            error!("Failed to parse buffer {}: {}", buffer_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;
        STORAGE_OPERATIONS
            .with_label_values(&["get_buffer", "success"])
            .inc();
        Ok(results.into_iter().next())
    }

    /// Delete buffers whose `expires_at` is set and in the past.
    ///
    /// Always returns `Ok(0)` — the delete count is not extracted from the
    /// response (see comment below), so callers must not rely on the value.
    async fn cleanup_expired_buffers(&self) -> crate::Result<u64> {
        debug!("Cleaning up expired buffers");
        // Compare as RFC 3339 strings; assumes stored `expires_at` values use
        // the same format — TODO confirm against the writers.
        let now = Utc::now().to_rfc3339();
        let query = "DELETE FROM rlm_buffers WHERE expires_at != NONE AND expires_at < $now";
        let mut response = self.db.query(query).bind(("now", now)).await.map_err(|e| {
            error!("Failed to cleanup expired buffers: {}", e);
            STORAGE_OPERATIONS
                .with_label_values(&["cleanup_buffers", "error"])
                .inc();
            RLMError::DatabaseError(Box::new(e))
        })?;
        // SurrealDB 2.x doesn't return delete count easily
        let _: Vec<serde_json::Value> = response.take(0).unwrap_or_default();
        STORAGE_OPERATIONS
            .with_label_values(&["cleanup_buffers", "success"])
            .inc();
        Ok(0)
    }

    /// Persist one execution-history record into `rlm_executions`.
    async fn save_execution(&self, execution: ExecutionHistory) -> crate::Result<()> {
        debug!(
            "Saving execution {} for document {}",
            execution.execution_id, execution.doc_id
        );
        let query = "CREATE rlm_executions SET execution_id = $execution_id, doc_id = $doc_id, \
                     query = $query, chunks_used = $chunks_used, result = $result, duration_ms = \
                     $duration_ms, cost_cents = $cost_cents, provider = $provider, success = \
                     $success, error_message = $error_message, metadata = $metadata, created_at = \
                     $created_at, executed_at = $executed_at";
        let result = self
            .db
            .query(query)
            // Clones keep `execution` alive for the error log below.
            .bind(("execution_id", execution.execution_id.clone()))
            .bind(("doc_id", execution.doc_id.clone()))
            .bind(("query", execution.query.clone()))
            .bind(("chunks_used", execution.chunks_used.clone()))
            .bind(("result", execution.result.clone()))
            .bind(("duration_ms", execution.duration_ms as i64))
            .bind(("cost_cents", execution.cost_cents))
            .bind(("provider", execution.provider.clone()))
            .bind(("success", execution.success))
            .bind(("error_message", execution.error_message.clone()))
            .bind(("metadata", execution.metadata.clone()))
            .bind(("created_at", execution.created_at.clone()))
            .bind(("executed_at", execution.executed_at.clone()))
            .await;
        match result {
            Ok(_) => {
                STORAGE_OPERATIONS
                    .with_label_values(&["save_execution", "success"])
                    .inc();
                Ok(())
            }
            Err(e) => {
                error!("Failed to save execution {}: {}", execution.execution_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["save_execution", "error"])
                    .inc();
                Err(RLMError::DatabaseError(Box::new(e)))
            }
        }
    }

    /// Fetch up to `limit` execution records for `doc_id`, newest first
    /// (ordered by `executed_at` descending).
    async fn get_executions(
        &self,
        doc_id: &str,
        limit: usize,
    ) -> crate::Result<Vec<ExecutionHistory>> {
        debug!(
            "Fetching executions for document {} (limit: {})",
            doc_id, limit
        );
        let query = "SELECT * FROM rlm_executions WHERE doc_id = $doc_id ORDER BY executed_at \
                     DESC LIMIT $limit";
        let mut response = self
            .db
            .query(query)
            .bind(("doc_id", doc_id.to_string()))
            .bind(("limit", limit as i64))
            .await
            .map_err(|e| {
                error!("Failed to fetch executions for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["get_executions", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        let results: Vec<ExecutionHistory> = response.take(0).map_err(|e| {
            error!("Failed to parse executions for doc {}: {}", doc_id, e);
            RLMError::DatabaseError(Box::new(e))
        })?;
        STORAGE_OPERATIONS
            .with_label_values(&["get_executions", "success"])
            .inc();
        Ok(results)
    }

    /// Delete every chunk belonging to `doc_id`.
    ///
    /// Always returns `Ok(0)` — the delete count is not extracted from the
    /// response (see comment below), so callers must not rely on the value.
    async fn delete_chunks(&self, doc_id: &str) -> crate::Result<u64> {
        debug!("Deleting chunks for document {}", doc_id);
        let query = "DELETE FROM rlm_chunks WHERE doc_id = $doc_id";
        let mut response = self
            .db
            .query(query)
            .bind(("doc_id", doc_id.to_string()))
            .await
            .map_err(|e| {
                error!("Failed to delete chunks for doc {}: {}", doc_id, e);
                STORAGE_OPERATIONS
                    .with_label_values(&["delete_chunks", "error"])
                    .inc();
                RLMError::DatabaseError(Box::new(e))
            })?;
        // SurrealDB 2.x doesn't return delete count easily
        let _: Vec<serde_json::Value> = response.take(0).unwrap_or_default();
        STORAGE_OPERATIONS
            .with_label_values(&["delete_chunks", "success"])
            .inc();
        Ok(0)
    }
}
impl SurrealDBStorage {
    /// Rank chunks by cosine similarity to `query_embedding` (in-memory)
    /// and return the top `limit` matches, most similar first.
    ///
    /// Chunks without an embedding are skipped. Unlike the naive approach,
    /// scoring works over indices so only the winning `limit` chunks are
    /// cloned, instead of cloning every candidate up front.
    fn rank_by_similarity(
        &self,
        chunks: &[Chunk],
        query_embedding: &[f32],
        limit: usize,
    ) -> Vec<Chunk> {
        // Score (similarity, index) pairs — no chunk clones yet.
        let mut scored: Vec<(f32, usize)> = chunks
            .iter()
            .enumerate()
            .filter_map(|(idx, chunk)| {
                chunk
                    .embedding
                    .as_ref()
                    .map(|embedding| (cosine_similarity(embedding, query_embedding), idx))
            })
            .collect();
        // Sort by similarity descending; NaN scores compare as Equal, which
        // leaves their relative order unchanged (stable sort).
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        // Clone only the top N survivors.
        scored
            .into_iter()
            .take(limit)
            .map(|(_, idx)| chunks[idx].clone())
            .collect()
    }
}
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when the lengths differ or either vector has zero magnitude,
/// so degenerate inputs rank last instead of producing NaN or infinity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let (mut dot, mut norm_sq_a, mut norm_sq_b) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_sq_a += x * x;
        norm_sq_b += y * y;
    }
    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }
    dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors are perfectly similar.
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]) - 1.0).abs() < 0.001);
        // Orthogonal vectors have zero similarity.
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]) - 0.0).abs() < 0.001);
        // 45-degree angle: similarity is 1/sqrt(2) ~= 0.707.
        let similarity = cosine_similarity(&[1.0, 1.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!(similarity > 0.7 && similarity < 0.8);
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        // Empty inputs, mismatched lengths, and a zero-magnitude vector all
        // yield the 0.0 sentinel rather than NaN.
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
}