prvng_platform/crates/rag/src/db.rs

477 lines
15 KiB
Rust

//! SurrealDB connection management and initialization
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use surrealdb::types::SurrealValue;
use surrealdb::engine::any::{connect, Any};
use surrealdb::Surreal;
use crate::config::VectorDbConfig;
use crate::embeddings::EmbeddedDocument;
use crate::error::Result;
/// Handle to the SurrealDB instance backing the RAG system.
///
/// Bundles the live client with the configuration used to open it, so that
/// schema setup, writes, lookups and statistics all address the same
/// namespace, database and table names.
pub struct DbConnection {
    /// Client already switched to the configured namespace and database.
    client: Surreal<Any>,
    /// Connection settings, including the table names this type operates on.
    config: VectorDbConfig,
}
impl DbConnection {
/// Create a new database connection
pub async fn new(config: VectorDbConfig) -> Result<Self> {
// Parse SurrealDB URL
let url = if config.url == "memory" {
"mem://".to_string()
} else {
config.url.clone()
};
// Create connection using connect() function
let client = connect(&url)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
// Select namespace and database
client
.use_ns(&config.namespace)
.use_db(&config.database)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!(
"Connected to SurrealDB: {}/{}/{}",
url,
config.namespace,
config.database
);
Ok(Self { client, config })
}
/// Initialize RAG schema (create tables, indexes, etc.)
pub async fn initialize_schema(&self) -> Result<()> {
tracing::info!("Initializing RAG schema in SurrealDB");
// Create documents table
self.create_documents_table().await?;
// Create deployments table
self.create_deployments_table().await?;
// Create relationship tables
self.create_relationship_tables().await?;
// Create metadata tables
self.create_metadata_tables().await?;
tracing::info!("RAG schema initialized successfully");
Ok(())
}
/// Create documents table
async fn create_documents_table(&self) -> Result<()> {
let query = format!(
r#"
DEFINE TABLE {table} SCHEMAFULL;
DEFINE FIELD id ON {table} TYPE string;
DEFINE FIELD source_path ON {table} TYPE string;
DEFINE FIELD doc_type ON {table} TYPE string;
DEFINE FIELD category ON {table} TYPE string;
DEFINE FIELD content ON {table} TYPE string;
DEFINE FIELD embedding ON {table} TYPE array<float>;
DEFINE FIELD metadata ON {table} TYPE object;
DEFINE FIELD git_commit ON {table} TYPE string;
DEFINE FIELD indexed_at ON {table} TYPE datetime DEFAULT time::now();
DEFINE FIELD updated_at ON {table} TYPE datetime DEFAULT time::now();
DEFINE INDEX idx_embedding ON {table} FIELDS embedding HNSW
DIMENSION {dimension};
"#,
table = self.config.documents_table,
dimension = 1536
);
self.client
.query(&query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!("Created documents table: {}", self.config.documents_table);
Ok(())
}
/// Create deployments table
async fn create_deployments_table(&self) -> Result<()> {
let query = format!(
r#"
DEFINE TABLE {table} SCHEMAFULL;
DEFINE FIELD id ON {table} TYPE string;
DEFINE FIELD workspace ON {table} TYPE string;
DEFINE FIELD infrastructure ON {table} TYPE string;
DEFINE FIELD event_type ON {table} TYPE string;
DEFINE FIELD status ON {table} TYPE string;
DEFINE FIELD resource_name ON {table} TYPE string;
DEFINE FIELD provider ON {table} TYPE string;
DEFINE FIELD duration_ms ON {table} TYPE int;
DEFINE FIELD config_snapshot ON {table} TYPE object;
DEFINE FIELD embedding ON {table} TYPE array<float>;
DEFINE FIELD error ON {table} TYPE object;
DEFINE FIELD timestamp ON {table} TYPE datetime DEFAULT time::now();
DEFINE INDEX idx_workspace ON {table} FIELDS workspace;
DEFINE INDEX idx_status ON {table} FIELDS status;
DEFINE INDEX idx_embedding ON {table} FIELDS embedding HNSW
DIMENSION {dimension};
"#,
table = self.config.deployments_table,
dimension = 1536
);
self.client
.query(&query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!(
"Created deployments table: {}",
self.config.deployments_table
);
Ok(())
}
/// Create relationship tables
async fn create_relationship_tables(&self) -> Result<()> {
// Document relationships
let query = r#"
DEFINE TABLE doc_relates_to TYPE RELATION
FROM documents
TO documents;
DEFINE FIELD relevance_score ON doc_relates_to TYPE float DEFAULT 0.0;
"#;
self.client
.query(query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
// Deployment relationships
let query = r#"
DEFINE TABLE deployment_depends_on TYPE RELATION
FROM deployments
TO deployments;
DEFINE FIELD dependency_type ON deployment_depends_on TYPE string;
"#;
self.client
.query(query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!("Created relationship tables");
Ok(())
}
/// Create metadata tables
async fn create_metadata_tables(&self) -> Result<()> {
let query = r#"
DEFINE TABLE rag_index_metadata SCHEMAFULL;
DEFINE FIELD id ON rag_index_metadata TYPE string;
DEFINE FIELD last_reindex ON rag_index_metadata TYPE datetime;
DEFINE FIELD total_documents ON rag_index_metadata TYPE int DEFAULT 0;
DEFINE FIELD total_deployments ON rag_index_metadata TYPE int DEFAULT 0;
DEFINE FIELD reindex_status ON rag_index_metadata TYPE string DEFAULT "idle";
"#;
self.client
.query(query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!("Created metadata tables");
Ok(())
}
/// Get SurrealDB client
pub fn client(&self) -> &Surreal<Any> {
&self.client
}
/// Get configuration
pub fn config(&self) -> &VectorDbConfig {
&self.config
}
/// Health check
pub async fn health_check(&self) -> Result<bool> {
match self
.client
.query("SELECT * FROM rag_index_metadata LIMIT 1")
.await
{
Ok(_) => {
tracing::debug!("SurrealDB health check passed");
Ok(true)
}
Err(e) => {
tracing::warn!("SurrealDB health check failed: {}", e);
Ok(false)
}
}
}
/// Store a single embedded document in the database
pub async fn store_document(&self, doc: &EmbeddedDocument) -> Result<()> {
let table = &self.config.documents_table;
// Build metadata map
let metadata: HashMap<String, String> = doc
.metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
// Use raw insert with JSON binding
let query = serde_json::json!({
"id": &doc.id,
"source_path": &doc.source_path,
"doc_type": &doc.doc_type,
"content": &doc.content,
"embedding": &doc.embedding,
"metadata": metadata,
});
self.client
.query(format!("CREATE {} CONTENT $content;", table))
.bind(("content", query))
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::debug!("Stored document: {}", doc.id);
Ok(())
}
/// Store multiple embedded documents in batch
pub async fn store_documents(&self, docs: &[EmbeddedDocument]) -> Result<usize> {
let mut count = 0;
for doc in docs {
if let Err(e) = self.store_document(doc).await {
tracing::warn!("Failed to store document {}: {}", doc.id, e);
// Continue with next document on error
continue;
}
count += 1;
}
tracing::info!("Stored {} documents", count);
Ok(count)
}
/// Store a deployment event
pub async fn store_deployment_event(
&self,
workspace: &str,
infrastructure: &str,
event_type: &str,
status: &str,
resource_name: &str,
provider: &str,
) -> Result<String> {
let table = &self.config.deployments_table;
let event = serde_json::json!({
"workspace": workspace,
"infrastructure": infrastructure,
"event_type": event_type,
"status": status,
"resource_name": resource_name,
"provider": provider,
"timestamp": chrono::Utc::now().to_rfc3339(),
});
self.client
.query(format!("CREATE {} CONTENT $event;", table))
.bind(("event", event))
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!(
"Stored deployment event: {} {} {}",
event_type,
resource_name,
status
);
let id = format!("{}/{}", table, "auto_generated");
Ok(id)
}
/// Get a document by ID
pub async fn get_document(&self, doc_id: &str) -> Result<Option<DocumentRecord>> {
let query = format!(
"SELECT * FROM {} WHERE id = $id;",
self.config.documents_table
);
let bindings = serde_json::json!({
"id": doc_id.to_string(),
});
let mut response = self
.client
.query(&query)
.bind(bindings)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
let records: Vec<DocumentRecord> = response
.take(0)
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
Ok(records.into_iter().next())
}
/// Search for similar documents using vector similarity
pub async fn search_similar(
&self,
embedding: &[f32],
limit: usize,
threshold: f32,
) -> Result<Vec<DocumentRecord>> {
// SurrealDB HNSW search query
let query = format!(
r#"
SELECT *,
vector::similarity::cosine(embedding, $embedding) AS similarity
FROM {}
WHERE vector::similarity::cosine(embedding, $embedding) > $threshold
ORDER BY similarity DESC
LIMIT $limit;
"#,
self.config.documents_table
);
let embedding_vec: Vec<f32> = embedding.to_vec();
let bindings = serde_json::json!({
"embedding": embedding_vec,
"threshold": threshold,
"limit": limit,
});
let mut response = self
.client
.query(&query)
.bind(bindings)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
let records: Vec<DocumentRecord> = response
.take(0)
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::debug!("Found {} similar documents", records.len());
Ok(records)
}
/// Delete a document by ID
pub async fn delete_document(&self, doc_id: &str) -> Result<()> {
let query = format!(
"DELETE FROM {} WHERE id = $id;",
self.config.documents_table
);
let bindings = serde_json::json!({
"id": doc_id.to_string(),
});
self.client
.query(&query)
.bind(bindings)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
tracing::info!("Deleted document: {}", doc_id);
Ok(())
}
/// Get RAG system statistics
pub async fn get_statistics(&self) -> Result<RagStatistics> {
let doc_count_query = format!(
"SELECT count() as count FROM {} GROUP ALL;",
self.config.documents_table
);
let deploy_count_query = format!(
"SELECT count() as count FROM {} GROUP ALL;",
self.config.deployments_table
);
let mut doc_response = self
.client
.query(&doc_count_query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
let mut deploy_response = self
.client
.query(&deploy_count_query)
.await
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
let doc_count: Vec<CountResult> = doc_response.take(0).unwrap_or_default();
let deploy_count: Vec<CountResult> = deploy_response.take(0).unwrap_or_default();
let documents = doc_count.first().map(|r| r.count).unwrap_or(0);
let deployments = deploy_count.first().map(|r| r.count).unwrap_or(0);
Ok(RagStatistics {
total_documents: documents,
total_deployments: deployments,
last_updated: chrono::Utc::now().to_rfc3339(),
})
}
}
/// Document record stored in SurrealDB
///
/// Row shape returned by `DbConnection::get_document` and
/// `DbConnection::search_similar`.
///
/// NOTE(review): the schema defines `indexed_at` / `updated_at` as `datetime` and `id`
/// is SurrealDB's record id. Confirm that these decode into the `String` fields
/// declared here. The `similarity` column computed by `search_similar` has no field
/// here and is presumably discarded on decode; confirm.
#[derive(Debug, Clone, Serialize, Deserialize, SurrealValue)]
#[surreal(crate = "surrealdb::types")]
pub struct DocumentRecord {
/// Document identifier.
pub id: String,
/// Path of the source the document was indexed from.
pub source_path: String,
/// Document type label.
pub doc_type: String,
/// Text content of the document.
pub content: String,
/// Embedding vector stored alongside the content.
pub embedding: Vec<f32>,
/// String key/value metadata.
pub metadata: HashMap<String, String>,
/// When the document was indexed (stored as `datetime` in the schema).
pub indexed_at: String,
/// When the document was last updated (stored as `datetime` in the schema).
pub updated_at: String,
}
/// Count result for statistics queries
///
/// One row of a `SELECT count() as count ... GROUP ALL` query.
/// NOTE(review): SurrealDB integers are presumably 64-bit; counts above `i32::MAX`
/// would not fit this field. Confirm the expected table sizes.
#[derive(Debug, Deserialize, SurrealValue)]
#[surreal(crate = "surrealdb::types")]
struct CountResult {
/// Number of rows counted.
count: i32,
}
/// RAG system statistics
///
/// Snapshot produced by `DbConnection::get_statistics`.
#[derive(Debug, Clone, Serialize)]
pub struct RagStatistics {
/// Number of rows in the configured documents table.
pub total_documents: i32,
/// Number of rows in the configured deployments table.
pub total_deployments: i32,
/// RFC 3339 timestamp of when the statistics were gathered.
pub last_updated: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration pointing at the embedded in-memory engine.
    fn memory_config() -> VectorDbConfig {
        VectorDbConfig {
            url: "memory".to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_memory_db_connection() {
        let result = DbConnection::new(memory_config()).await;
        if let Err(e) = result {
            panic!("in-memory connection failed: {e}");
        }
    }

    /// Schema initialization must be repeatable (e.g. on every service start), and
    /// the health check must pass once the schema exists.
    #[tokio::test]
    async fn test_schema_initialization_is_repeatable() {
        let db = match DbConnection::new(memory_config()).await {
            Ok(db) => db,
            Err(e) => panic!("in-memory connection failed: {e}"),
        };

        for attempt in 1..=2 {
            if let Err(e) = db.initialize_schema().await {
                panic!("initialize_schema attempt {attempt} failed: {e}");
            }
        }

        match db.health_check().await {
            Ok(healthy) => assert!(healthy, "health check reported unhealthy"),
            Err(e) => panic!("health check returned an error: {e}"),
        }
    }
}