use std::sync::Arc;

use chrono::{Duration, Utc};
use dashmap::DashMap;
use tracing::{debug, warn};

use crate::error::Result;
use crate::models::*;

/// Temporal Knowledge Graph for storing and querying agent execution history
/// Phase 5.1: Uses embedding-based similarity for semantic matching
pub struct TemporalKG {
    records: Arc<DashMap<String, ExecutionRecord>>,
    profiles: Arc<DashMap<String, AgentProfile>>,
    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
    embedding_cache: Arc<DashMap<String, Vec<f32>>>,
}

impl TemporalKG {
    /// Create new temporal KG with in-memory storage
    pub async fn new(_db_url: &str, _user: &str, _pass: &str) -> Result<Self> {
        debug!("Initializing temporal knowledge graph");

        Ok(Self {
            records: Arc::new(DashMap::new()),
            profiles: Arc::new(DashMap::new()),
            embedding_provider: None,
            embedding_cache: Arc::new(DashMap::new()),
        })
    }

    /// Create temporal KG with embedding provider (Phase 5.1)
    pub async fn with_embeddings(
        _db_url: &str,
        _user: &str,
        _pass: &str,
        embedding_provider: Arc<dyn EmbeddingProvider>,
    ) -> Result<Self> {
        debug!(
            "Initializing temporal KG with embeddings ({})",
            embedding_provider.provider_name()
        );

        Ok(Self {
            records: Arc::new(DashMap::new()),
            profiles: Arc::new(DashMap::new()),
            embedding_provider: Some(embedding_provider),
            embedding_cache: Arc::new(DashMap::new()),
        })
    }

    /// Get or compute embedding for text (with caching)
    async fn get_or_embed(&self, text: &str) -> Result<Option<Vec<f32>>> {
        if let Some(provider) = &self.embedding_provider {
            let cache_key = format!("{:x}", md5::compute(text.as_bytes()));

            if let Some(cached) = self.embedding_cache.get(&cache_key) {
                return Ok(Some(cached.clone()));
            }

            match provider.embed(text).await {
                Ok(embedding) => {
                    self.embedding_cache.insert(cache_key, embedding.clone());
                    Ok(Some(embedding))
                }
                Err(e) => {
                    warn!("Failed to generate embedding: {}", e);
                    Ok(None) // Fallback to Jaccard if embedding fails
                }
            }
        } else {
            Ok(None)
        }
    }

    /// Compute vector similarity using cosine distance
    fn compute_vector_similarity(vec_a: &[f32], vec_b: &[f32]) -> f64 {
        if vec_a.is_empty() || vec_b.is_empty() {
            return 0.0;
        }

        let dot_product: f32 = vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum();
        let norm_a: f32 = vec_a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = vec_b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        (dot_product / (norm_a * norm_b)) as f64
    }

    /// Record task execution for learning
    pub async fn record_execution(&self, record: ExecutionRecord) -> Result<()> {
        debug!("Recording execution: {}", record.id);
        self.records.insert(record.id.clone(), record);
        Ok(())
    }

    /// Query similar tasks within 90 days (Phase 5.1: uses embeddings if
    /// available)
    pub async fn query_similar_tasks(
        &self,
        task_type: &str,
        description: &str,
    ) -> Result<Vec<ExecutionRecord>> {
        let now = Utc::now();
        let cutoff = now - Duration::days(90);
        let threshold = 0.4; // Similarity threshold

        let query_embedding = self.get_or_embed(description).await.ok().flatten();

        let mut similar_with_scores = Vec::new();

        for entry in self.records.iter() {
            let record = entry.value();
            if record.timestamp > cutoff && record.task_type == task_type {
                let similarity = if let Some(ref query_emb) = query_embedding {
                    // Phase 5.1: Use vector embedding similarity
                    if let Ok(Some(record_emb)) = self.get_or_embed(&record.description).await {
                        Self::compute_vector_similarity(query_emb, &record_emb)
                    } else {
                        // Fallback to Jaccard if embedding fails
                        calculate_similarity(description, &record.description)
                    }
                } else {
                    // Fallback to Jaccard if no embedding provider
                    calculate_similarity(description, &record.description)
                };

                if similarity >= threshold {
                    similar_with_scores.push((record.clone(), similarity));
                }
            }
        }

        // Sort by similarity descending
        similar_with_scores
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        Ok(similar_with_scores
            .into_iter()
            .take(5)
            .map(|(record, _)| record)
            .collect())
    }

    /// Get recommendations from similar successful tasks (Phase 5.1:
    /// embedding-based)
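    ///
    /// # Example
    ///
    /// A minimal usage sketch; `kg` is assumed to be a populated `TemporalKG`
    /// and the snippet to run inside an async context (the task description
    /// here is purely illustrative):
    ///
    /// ```ignore
    /// let recs = kg.get_recommendations("coding", "Write a Rust parser").await?;
    /// if let Some(best) = recs.first() {
    ///     println!("suggested: {} (confidence {:.2})", best.solution, best.confidence);
    /// }
    /// ```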
    pub async fn get_recommendations(
        &self,
        task_type: &str,
        description: &str,
    ) -> Result<Vec<Recommendation>> {
        let similar_tasks = self.query_similar_tasks(task_type, description).await?;
        let query_embedding = self.get_or_embed(description).await.ok().flatten();

        let mut recommendations = Vec::new();
        for task in similar_tasks {
            if task.success {
                let confidence = if let Some(ref query_emb) = query_embedding {
                    if let Ok(Some(task_emb)) = self.get_or_embed(&task.description).await {
                        Self::compute_vector_similarity(query_emb, &task_emb)
                    } else {
                        calculate_similarity(description, &task.description)
                    }
                } else {
                    calculate_similarity(description, &task.description)
                };

                recommendations.push(Recommendation {
                    source_record_id: task.id.clone(),
                    source_agent_id: task.agent_id.clone(),
                    solution: task.solution.clone().unwrap_or_default(),
                    confidence,
                    estimated_duration_ms: task.duration_ms,
                    reasoning: format!(
                        "Similar task '{}' succeeded with solution: {}",
                        task.id,
                        task.solution.clone().unwrap_or_else(|| "N/A".to_string())
                    ),
                });
            }
        }

        Ok(recommendations)
    }

    /// Get agent expertise profile
    pub async fn get_agent_profile(&self, agent_id: &str) -> Result<AgentProfile> {
        let mut total_tasks = 0u64;
        let mut successful_tasks = 0u64;
        let mut task_types = std::collections::HashSet::new();
        let mut durations = Vec::new();

        for entry in self.records.iter() {
            let record = entry.value();
            if record.agent_id == agent_id {
                total_tasks += 1;
                task_types.insert(record.task_type.clone());
                durations.push(record.duration_ms);
                if record.success {
                    successful_tasks += 1;
                }
            }
        }

        let avg_duration = if !durations.is_empty() {
            durations.iter().sum::<u64>() as f64 / durations.len() as f64
        } else {
            0.0
        };

        let expertise_score = if total_tasks > 0 {
            (successful_tasks as f64 / total_tasks as f64) * 100.0
        } else {
            0.0
        };

        // Return existing profile or create new one
        if let Some(profile) = self.profiles.get(agent_id) {
            return Ok(profile.clone());
        }

        Ok(AgentProfile {
            agent_id: agent_id.to_string(),
            total_tasks,
            success_count: successful_tasks,
            avg_duration_ms: avg_duration,
            primary_task_types: task_types.into_iter().collect(),
            expertise_score,
            learning_curve: vec![],
        })
    }

    /// Get knowledge graph statistics
    pub async fn get_statistics(&self) -> Result<GraphStatistics> {
        let total_records = self.records.len() as u64;
        let successful = self.records.iter().filter(|e| e.value().success).count() as u64;
        let failed = total_records - successful;

        let mut avg_duration = 0.0;
        let mut total_duration = 0u64;
        let mut distinct_agents = std::collections::HashSet::new();
        let mut task_types = std::collections::HashSet::new();

        for entry in self.records.iter() {
            let record = entry.value();
            total_duration += record.duration_ms;
            distinct_agents.insert(record.agent_id.clone());
            task_types.insert(record.task_type.clone());
        }

        if total_records > 0 {
            avg_duration = total_duration as f64 / total_records as f64;
        }

        Ok(GraphStatistics {
            total_records,
            total_successful: successful,
            total_failed: failed,
            success_rate: if total_records > 0 {
                successful as f64 / total_records as f64
            } else {
                0.0
            },
            avg_duration_ms: avg_duration,
            distinct_agents: distinct_agents.len() as u32,
            distinct_task_types: task_types.len() as u32,
        })
    }

    /// Find causal relationships (error patterns) - Phase 5.1: embedding-based
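    ///
    /// # Example
    ///
    /// A minimal sketch; `kg` is assumed to already hold failed executions
    /// whose `error` fields resemble the query pattern (the pattern string is
    /// illustrative):
    ///
    /// ```ignore
    /// let rels = kg.find_causal_relationships("connection timeout").await?;
    /// for rel in &rels {
    ///     println!("{} -> {} (seen {} times)", rel.cause, rel.effect, rel.frequency);
    /// }
    /// ```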
    pub async fn find_causal_relationships(
        &self,
        cause_pattern: &str,
    ) -> Result<Vec<CausalRelationship>> {
        let mut relationships = Vec::new();
        let threshold = 0.5;

        let pattern_embedding = self.get_or_embed(cause_pattern).await.ok().flatten();

        for entry in self.records.iter() {
            let record = entry.value();
            if !record.success {
                if let Some(error) = &record.error {
                    let similarity = if let Some(ref pattern_emb) = pattern_embedding {
                        if let Ok(Some(error_emb)) = self.get_or_embed(error).await {
                            Self::compute_vector_similarity(pattern_emb, &error_emb)
                        } else {
                            calculate_similarity(cause_pattern, error)
                        }
                    } else {
                        calculate_similarity(cause_pattern, error)
                    };

                    if similarity >= threshold {
                        relationships.push(CausalRelationship {
                            cause: error.clone(),
                            effect: record
                                .solution
                                .clone()
                                .unwrap_or_else(|| "unknown".to_string()),
                            confidence: similarity,
                            frequency: 1,
                        });
                    }
                }
            }
        }

        // Deduplicate and count occurrences
        let mut deduped: std::collections::HashMap<String, CausalRelationship> =
            std::collections::HashMap::new();
        for rel in relationships {
            deduped
                .entry(rel.cause.clone())
                .and_modify(|r| r.frequency += 1)
                .or_insert(rel);
        }

        Ok(deduped.into_values().collect())
    }

    /// Check if embeddings are enabled
    pub fn has_embeddings(&self) -> bool {
        self.embedding_provider.is_some()
    }

    /// Get embedding provider name if available
    pub fn embedding_provider_name(&self) -> Option<&str> {
        self.embedding_provider.as_ref().map(|p| p.provider_name())
    }

    /// Clear all data (for testing)
    #[cfg(test)]
    pub fn clear(&self) {
        self.records.clear();
        self.profiles.clear();
        self.embedding_cache.clear();
    }
}

/// Calculate similarity between two texts using Jaccard coefficient
fn calculate_similarity(text_a: &str, text_b: &str) -> f64 {
    let words_a: std::collections::HashSet<_> = text_a.split_whitespace().collect();
    let words_b: std::collections::HashSet<_> = text_b.split_whitespace().collect();

    if words_a.is_empty() && words_b.is_empty() {
        return 1.0;
    }

    let intersection = words_a.intersection(&words_b).count();
    let union = words_a.union(&words_b).count();

    if union == 0 {
        0.0
    } else {
        intersection as f64 / union as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_kg_creation() {
        let kg = TemporalKG::new("ws://localhost:8000", "root", "root")
            .await
            .unwrap();
        let stats = kg.get_statistics().await.unwrap();
        assert_eq!(stats.total_records, 0);
    }

    #[tokio::test]
    async fn test_record_execution() {
        let kg = TemporalKG::new("ws://localhost:8000", "root", "root")
            .await
            .unwrap();

        let record = ExecutionRecord {
            id: "exec-1".to_string(),
            task_id: "task-1".to_string(),
            agent_id: "agent-1".to_string(),
            agent_role: None,
            task_type: "coding".to_string(),
            description: "Write a Rust function".to_string(),
            duration_ms: 5000,
            input_tokens: 100,
            output_tokens: 250,
            cost_cents: 50,
            provider: "claude".to_string(),
            success: true,
            error: None,
            solution: Some("Use async/await pattern".to_string()),
            root_cause: None,
            timestamp: Utc::now(),
        };

        kg.record_execution(record).await.unwrap();

        let stats = kg.get_statistics().await.unwrap();
        assert_eq!(stats.total_records, 1);
        assert_eq!(stats.total_successful, 1);
    }
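    // Illustrative sketch of the private similarity helpers: identical token
    // sets should score 1.0 under Jaccard, a half-overlapping set 0.5, and
    // orthogonal vectors 0.0 under cosine similarity. Not exhaustive; the
    // expected values follow directly from the definitions above.
    #[test]
    fn test_similarity_helpers() {
        assert!((calculate_similarity("write rust code", "write rust code") - 1.0).abs() < 1e-9);
        assert!((calculate_similarity("write rust code", "write rust tests") - 0.5).abs() < 1e-9);
        let cos = TemporalKG::compute_vector_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(cos.abs() < 1e-6);
    }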
    #[tokio::test]
    async fn test_query_similar_tasks() {
        let kg = TemporalKG::new("ws://localhost:8000", "root", "root")
            .await
            .unwrap();

        let record1 = ExecutionRecord {
            id: "exec-1".to_string(),
            task_id: "task-1".to_string(),
            agent_id: "agent-1".to_string(),
            agent_role: None,
            task_type: "coding".to_string(),
            description: "Write a Rust function for data processing".to_string(),
            duration_ms: 5000,
            input_tokens: 100,
            output_tokens: 250,
            cost_cents: 60,
            provider: "claude".to_string(),
            success: true,
            error: None,
            solution: Some("Use async/await".to_string()),
            root_cause: None,
            timestamp: Utc::now(),
        };

        kg.record_execution(record1).await.unwrap();

        let similar = kg
            .query_similar_tasks("coding", "Write a Rust function for processing data")
            .await
            .unwrap();
        assert!(!similar.is_empty());
    }

    #[tokio::test]
    async fn test_agent_profile() {
        let kg = TemporalKG::new("ws://localhost:8000", "root", "root")
            .await
            .unwrap();

        let record = ExecutionRecord {
            id: "exec-1".to_string(),
            task_id: "task-1".to_string(),
            agent_id: "agent-1".to_string(),
            agent_role: None,
            task_type: "coding".to_string(),
            description: "Write code".to_string(),
            duration_ms: 5000,
            input_tokens: 100,
            output_tokens: 250,
            cost_cents: 55,
            provider: "claude".to_string(),
            success: true,
            error: None,
            solution: None,
            root_cause: None,
            timestamp: Utc::now(),
        };

        kg.record_execution(record).await.unwrap();

        let profile = kg.get_agent_profile("agent-1").await.unwrap();
        assert_eq!(profile.agent_id, "agent-1");
        assert_eq!(profile.total_tasks, 1);
        assert_eq!(profile.success_count, 1);
    }
}