477 lines
15 KiB
Rust
477 lines
15 KiB
Rust
//! SurrealDB connection management and initialization
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use surrealdb::types::SurrealValue;
|
|
use surrealdb::engine::any::{connect, Any};
|
|
use surrealdb::Surreal;
|
|
|
|
use crate::config::VectorDbConfig;
|
|
use crate::embeddings::EmbeddedDocument;
|
|
use crate::error::Result;
|
|
|
|
/// SurrealDB connection manager for RAG system
|
|
pub struct DbConnection {
|
|
client: Surreal<Any>,
|
|
config: VectorDbConfig,
|
|
}
|
|
|
|
impl DbConnection {
|
|
/// Create a new database connection
|
|
pub async fn new(config: VectorDbConfig) -> Result<Self> {
|
|
// Parse SurrealDB URL
|
|
let url = if config.url == "memory" {
|
|
"mem://".to_string()
|
|
} else {
|
|
config.url.clone()
|
|
};
|
|
|
|
// Create connection using connect() function
|
|
let client = connect(&url)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
// Select namespace and database
|
|
client
|
|
.use_ns(&config.namespace)
|
|
.use_db(&config.database)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!(
|
|
"Connected to SurrealDB: {}/{}/{}",
|
|
url,
|
|
config.namespace,
|
|
config.database
|
|
);
|
|
|
|
Ok(Self { client, config })
|
|
}
|
|
|
|
/// Initialize RAG schema (create tables, indexes, etc.)
|
|
pub async fn initialize_schema(&self) -> Result<()> {
|
|
tracing::info!("Initializing RAG schema in SurrealDB");
|
|
|
|
// Create documents table
|
|
self.create_documents_table().await?;
|
|
|
|
// Create deployments table
|
|
self.create_deployments_table().await?;
|
|
|
|
// Create relationship tables
|
|
self.create_relationship_tables().await?;
|
|
|
|
// Create metadata tables
|
|
self.create_metadata_tables().await?;
|
|
|
|
tracing::info!("RAG schema initialized successfully");
|
|
Ok(())
|
|
}
|
|
|
|
/// Create documents table
|
|
async fn create_documents_table(&self) -> Result<()> {
|
|
let query = format!(
|
|
r#"
|
|
DEFINE TABLE {table} SCHEMAFULL;
|
|
DEFINE FIELD id ON {table} TYPE string;
|
|
DEFINE FIELD source_path ON {table} TYPE string;
|
|
DEFINE FIELD doc_type ON {table} TYPE string;
|
|
DEFINE FIELD category ON {table} TYPE string;
|
|
DEFINE FIELD content ON {table} TYPE string;
|
|
DEFINE FIELD embedding ON {table} TYPE array<float>;
|
|
DEFINE FIELD metadata ON {table} TYPE object;
|
|
DEFINE FIELD git_commit ON {table} TYPE string;
|
|
DEFINE FIELD indexed_at ON {table} TYPE datetime DEFAULT time::now();
|
|
DEFINE FIELD updated_at ON {table} TYPE datetime DEFAULT time::now();
|
|
DEFINE INDEX idx_embedding ON {table} FIELDS embedding HNSW
|
|
DIMENSION {dimension};
|
|
"#,
|
|
table = self.config.documents_table,
|
|
dimension = 1536
|
|
);
|
|
|
|
self.client
|
|
.query(&query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!("Created documents table: {}", self.config.documents_table);
|
|
Ok(())
|
|
}
|
|
|
|
/// Create deployments table
|
|
async fn create_deployments_table(&self) -> Result<()> {
|
|
let query = format!(
|
|
r#"
|
|
DEFINE TABLE {table} SCHEMAFULL;
|
|
DEFINE FIELD id ON {table} TYPE string;
|
|
DEFINE FIELD workspace ON {table} TYPE string;
|
|
DEFINE FIELD infrastructure ON {table} TYPE string;
|
|
DEFINE FIELD event_type ON {table} TYPE string;
|
|
DEFINE FIELD status ON {table} TYPE string;
|
|
DEFINE FIELD resource_name ON {table} TYPE string;
|
|
DEFINE FIELD provider ON {table} TYPE string;
|
|
DEFINE FIELD duration_ms ON {table} TYPE int;
|
|
DEFINE FIELD config_snapshot ON {table} TYPE object;
|
|
DEFINE FIELD embedding ON {table} TYPE array<float>;
|
|
DEFINE FIELD error ON {table} TYPE object;
|
|
DEFINE FIELD timestamp ON {table} TYPE datetime DEFAULT time::now();
|
|
DEFINE INDEX idx_workspace ON {table} FIELDS workspace;
|
|
DEFINE INDEX idx_status ON {table} FIELDS status;
|
|
DEFINE INDEX idx_embedding ON {table} FIELDS embedding HNSW
|
|
DIMENSION {dimension};
|
|
"#,
|
|
table = self.config.deployments_table,
|
|
dimension = 1536
|
|
);
|
|
|
|
self.client
|
|
.query(&query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!(
|
|
"Created deployments table: {}",
|
|
self.config.deployments_table
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Create relationship tables
|
|
async fn create_relationship_tables(&self) -> Result<()> {
|
|
// Document relationships
|
|
let query = r#"
|
|
DEFINE TABLE doc_relates_to TYPE RELATION
|
|
FROM documents
|
|
TO documents;
|
|
DEFINE FIELD relevance_score ON doc_relates_to TYPE float DEFAULT 0.0;
|
|
"#;
|
|
|
|
self.client
|
|
.query(query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
// Deployment relationships
|
|
let query = r#"
|
|
DEFINE TABLE deployment_depends_on TYPE RELATION
|
|
FROM deployments
|
|
TO deployments;
|
|
DEFINE FIELD dependency_type ON deployment_depends_on TYPE string;
|
|
"#;
|
|
|
|
self.client
|
|
.query(query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!("Created relationship tables");
|
|
Ok(())
|
|
}
|
|
|
|
/// Create metadata tables
|
|
async fn create_metadata_tables(&self) -> Result<()> {
|
|
let query = r#"
|
|
DEFINE TABLE rag_index_metadata SCHEMAFULL;
|
|
DEFINE FIELD id ON rag_index_metadata TYPE string;
|
|
DEFINE FIELD last_reindex ON rag_index_metadata TYPE datetime;
|
|
DEFINE FIELD total_documents ON rag_index_metadata TYPE int DEFAULT 0;
|
|
DEFINE FIELD total_deployments ON rag_index_metadata TYPE int DEFAULT 0;
|
|
DEFINE FIELD reindex_status ON rag_index_metadata TYPE string DEFAULT "idle";
|
|
"#;
|
|
|
|
self.client
|
|
.query(query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!("Created metadata tables");
|
|
Ok(())
|
|
}
|
|
|
|
/// Get SurrealDB client
|
|
pub fn client(&self) -> &Surreal<Any> {
|
|
&self.client
|
|
}
|
|
|
|
/// Get configuration
|
|
pub fn config(&self) -> &VectorDbConfig {
|
|
&self.config
|
|
}
|
|
|
|
/// Health check
|
|
pub async fn health_check(&self) -> Result<bool> {
|
|
match self
|
|
.client
|
|
.query("SELECT * FROM rag_index_metadata LIMIT 1")
|
|
.await
|
|
{
|
|
Ok(_) => {
|
|
tracing::debug!("SurrealDB health check passed");
|
|
Ok(true)
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("SurrealDB health check failed: {}", e);
|
|
Ok(false)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Store a single embedded document in the database
|
|
pub async fn store_document(&self, doc: &EmbeddedDocument) -> Result<()> {
|
|
let table = &self.config.documents_table;
|
|
|
|
// Build metadata map
|
|
let metadata: HashMap<String, String> = doc
|
|
.metadata
|
|
.iter()
|
|
.map(|(k, v)| (k.clone(), v.clone()))
|
|
.collect();
|
|
|
|
// Use raw insert with JSON binding
|
|
let query = serde_json::json!({
|
|
"id": &doc.id,
|
|
"source_path": &doc.source_path,
|
|
"doc_type": &doc.doc_type,
|
|
"content": &doc.content,
|
|
"embedding": &doc.embedding,
|
|
"metadata": metadata,
|
|
});
|
|
|
|
self.client
|
|
.query(format!("CREATE {} CONTENT $content;", table))
|
|
.bind(("content", query))
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::debug!("Stored document: {}", doc.id);
|
|
Ok(())
|
|
}
|
|
|
|
/// Store multiple embedded documents in batch
|
|
pub async fn store_documents(&self, docs: &[EmbeddedDocument]) -> Result<usize> {
|
|
let mut count = 0;
|
|
|
|
for doc in docs {
|
|
if let Err(e) = self.store_document(doc).await {
|
|
tracing::warn!("Failed to store document {}: {}", doc.id, e);
|
|
// Continue with next document on error
|
|
continue;
|
|
}
|
|
count += 1;
|
|
}
|
|
|
|
tracing::info!("Stored {} documents", count);
|
|
Ok(count)
|
|
}
|
|
|
|
/// Store a deployment event
|
|
pub async fn store_deployment_event(
|
|
&self,
|
|
workspace: &str,
|
|
infrastructure: &str,
|
|
event_type: &str,
|
|
status: &str,
|
|
resource_name: &str,
|
|
provider: &str,
|
|
) -> Result<String> {
|
|
let table = &self.config.deployments_table;
|
|
|
|
let event = serde_json::json!({
|
|
"workspace": workspace,
|
|
"infrastructure": infrastructure,
|
|
"event_type": event_type,
|
|
"status": status,
|
|
"resource_name": resource_name,
|
|
"provider": provider,
|
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
|
});
|
|
|
|
self.client
|
|
.query(format!("CREATE {} CONTENT $event;", table))
|
|
.bind(("event", event))
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!(
|
|
"Stored deployment event: {} {} {}",
|
|
event_type,
|
|
resource_name,
|
|
status
|
|
);
|
|
|
|
let id = format!("{}/{}", table, "auto_generated");
|
|
Ok(id)
|
|
}
|
|
|
|
/// Get a document by ID
|
|
pub async fn get_document(&self, doc_id: &str) -> Result<Option<DocumentRecord>> {
|
|
let query = format!(
|
|
"SELECT * FROM {} WHERE id = $id;",
|
|
self.config.documents_table
|
|
);
|
|
|
|
let bindings = serde_json::json!({
|
|
"id": doc_id.to_string(),
|
|
});
|
|
|
|
let mut response = self
|
|
.client
|
|
.query(&query)
|
|
.bind(bindings)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
let records: Vec<DocumentRecord> = response
|
|
.take(0)
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
Ok(records.into_iter().next())
|
|
}
|
|
|
|
/// Search for similar documents using vector similarity
|
|
pub async fn search_similar(
|
|
&self,
|
|
embedding: &[f32],
|
|
limit: usize,
|
|
threshold: f32,
|
|
) -> Result<Vec<DocumentRecord>> {
|
|
// SurrealDB HNSW search query
|
|
let query = format!(
|
|
r#"
|
|
SELECT *,
|
|
vector::similarity::cosine(embedding, $embedding) AS similarity
|
|
FROM {}
|
|
WHERE vector::similarity::cosine(embedding, $embedding) > $threshold
|
|
ORDER BY similarity DESC
|
|
LIMIT $limit;
|
|
"#,
|
|
self.config.documents_table
|
|
);
|
|
|
|
let embedding_vec: Vec<f32> = embedding.to_vec();
|
|
let bindings = serde_json::json!({
|
|
"embedding": embedding_vec,
|
|
"threshold": threshold,
|
|
"limit": limit,
|
|
});
|
|
|
|
let mut response = self
|
|
.client
|
|
.query(&query)
|
|
.bind(bindings)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
let records: Vec<DocumentRecord> = response
|
|
.take(0)
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::debug!("Found {} similar documents", records.len());
|
|
Ok(records)
|
|
}
|
|
|
|
/// Delete a document by ID
|
|
pub async fn delete_document(&self, doc_id: &str) -> Result<()> {
|
|
let query = format!(
|
|
"DELETE FROM {} WHERE id = $id;",
|
|
self.config.documents_table
|
|
);
|
|
|
|
let bindings = serde_json::json!({
|
|
"id": doc_id.to_string(),
|
|
});
|
|
|
|
self.client
|
|
.query(&query)
|
|
.bind(bindings)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
tracing::info!("Deleted document: {}", doc_id);
|
|
Ok(())
|
|
}
|
|
|
|
/// Get RAG system statistics
|
|
pub async fn get_statistics(&self) -> Result<RagStatistics> {
|
|
let doc_count_query = format!(
|
|
"SELECT count() as count FROM {} GROUP ALL;",
|
|
self.config.documents_table
|
|
);
|
|
let deploy_count_query = format!(
|
|
"SELECT count() as count FROM {} GROUP ALL;",
|
|
self.config.deployments_table
|
|
);
|
|
|
|
let mut doc_response = self
|
|
.client
|
|
.query(&doc_count_query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
let mut deploy_response = self
|
|
.client
|
|
.query(&deploy_count_query)
|
|
.await
|
|
.map_err(|e| crate::error::RagError::surrealdb(e.to_string()))?;
|
|
|
|
let doc_count: Vec<CountResult> = doc_response.take(0).unwrap_or_default();
|
|
|
|
let deploy_count: Vec<CountResult> = deploy_response.take(0).unwrap_or_default();
|
|
|
|
let documents = doc_count.first().map(|r| r.count).unwrap_or(0);
|
|
let deployments = deploy_count.first().map(|r| r.count).unwrap_or(0);
|
|
|
|
Ok(RagStatistics {
|
|
total_documents: documents,
|
|
total_deployments: deployments,
|
|
last_updated: chrono::Utc::now().to_rfc3339(),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Document record stored in SurrealDB
|
|
#[derive(Debug, Clone, Serialize, Deserialize, SurrealValue)]
|
|
#[surreal(crate = "surrealdb::types")]
|
|
pub struct DocumentRecord {
|
|
pub id: String,
|
|
pub source_path: String,
|
|
pub doc_type: String,
|
|
pub content: String,
|
|
pub embedding: Vec<f32>,
|
|
pub metadata: HashMap<String, String>,
|
|
pub indexed_at: String,
|
|
pub updated_at: String,
|
|
}
|
|
|
|
/// Count result for statistics queries
|
|
#[derive(Debug, Deserialize, SurrealValue)]
|
|
#[surreal(crate = "surrealdb::types")]
|
|
struct CountResult {
|
|
count: i32,
|
|
}
|
|
|
|
/// RAG system statistics
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct RagStatistics {
|
|
pub total_documents: i32,
|
|
pub total_deployments: i32,
|
|
pub last_updated: String,
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting with the `"memory"` URL shorthand must succeed.
    #[tokio::test]
    async fn test_memory_db_connection() {
        let in_memory = VectorDbConfig {
            url: String::from("memory"),
            ..Default::default()
        };

        let connection = DbConnection::new(in_memory).await;

        assert!(connection.is_ok());
    }
}
|