prvng_platform/crates/ai-service/src/tool_integration.rs

204 lines
6.2 KiB
Rust
Raw Normal View History

//! Tool integration and hybrid execution mode
//!
//! Analyzes RAG responses to suggest tool executions and enriches answers with
//! tool results.
use serde_json::Value;
/// Tool suggestion from RAG answer analysis
///
/// Produced by [`analyze_for_tools`]; describes one tool invocation that may
/// help answer the user's question.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolSuggestion {
    /// Name of the tool to invoke (e.g. `"guidance_check_system_status"`).
    pub tool_name: String,
    /// Heuristic confidence score; [`analyze_for_tools`] assigns fixed values
    /// between 0.5 and 0.7.
    pub confidence: f32,
    /// JSON arguments for the tool. As built by [`analyze_for_tools`] this is
    /// always a JSON object, possibly empty.
    pub args: Value,
}
/// Analyzes RAG answers to suggest relevant tools
///
/// Uses case-insensitive keyword (substring) matching on `question` to identify
/// tools that might be useful. Each rule below adds at most one
/// [`ToolSuggestion`] with a fixed confidence:
///
/// - status/health keywords → `guidance_check_system_status` (0.7)
/// - error/failure keywords → `guidance_troubleshoot` (0.65), with an error
///   description derived from the question as the `error` argument
/// - configuration keywords → `guidance_validate_config` (0.6), pointed at
///   `/etc/provisioning/config.toml`
/// - search keywords → `rag_semantic_search` (0.6), with the question as
///   `query` and `top_k = 5`
/// - follow-up keywords → `guidance_suggest_next_action` (0.55)
/// - documentation keywords → `guidance_find_docs` (0.5), with a topic
///   extracted from the question as `query`
///
/// `_answer` is currently not inspected; only `question` drives the result.
///
/// Returns every matching suggestion (no confidence threshold is applied),
/// sorted by confidence, highest first. Suggestions with equal confidence keep
/// the order in which their rules are checked.
pub fn analyze_for_tools(_answer: &str, question: &str) -> Vec<ToolSuggestion> {
    let mut suggestions = Vec::new();

    // System status patterns
    if question_matches_any(
        question,
        &["status", "health", "running", "check", "what's"],
    ) {
        suggestions.push(ToolSuggestion {
            tool_name: "guidance_check_system_status".to_string(),
            confidence: 0.7,
            args: serde_json::json!({}),
        });
    }

    // Configuration validation patterns
    if question_matches_any(
        question,
        &["valid", "config", "configuration", "validate", "verify"],
    ) {
        suggestions.push(ToolSuggestion {
            tool_name: "guidance_validate_config".to_string(),
            confidence: 0.6,
            args: serde_json::json!({
                "config_path": "/etc/provisioning/config.toml"
            }),
        });
    }

    // Documentation patterns
    if question_matches_any(question, &["doc", "help", "guide", "tutorial", "how to"]) {
        suggestions.push(ToolSuggestion {
            tool_name: "guidance_find_docs".to_string(),
            confidence: 0.5,
            args: serde_json::json!({
                "query": extract_main_topic(question)
            }),
        });
    }

    // Troubleshooting patterns
    if question_matches_any(
        question,
        &["error", "fail", "problem", "issue", "fix", "debug"],
    ) {
        suggestions.push(ToolSuggestion {
            tool_name: "guidance_troubleshoot".to_string(),
            confidence: 0.65,
            args: serde_json::json!({
                "error": extract_error_description(question)
            }),
        });
    }

    // Next action suggestions
    if question_matches_any(question, &["next", "then", "after", "what should"]) {
        suggestions.push(ToolSuggestion {
            tool_name: "guidance_suggest_next_action".to_string(),
            confidence: 0.55,
            args: serde_json::json!({}),
        });
    }

    // RAG tools based on keywords
    if question_matches_any(question, &["search", "find", "look for", "retrieve"]) {
        suggestions.push(ToolSuggestion {
            tool_name: "rag_semantic_search".to_string(),
            confidence: 0.6,
            args: serde_json::json!({
                "query": question.to_string(),
                "top_k": 5
            }),
        });
    }

    // Sort highest confidence first; nothing is filtered out.
    // `total_cmp` is a total order on f32, so unlike `partial_cmp(..).unwrap()`
    // this cannot panic. `sort_by` is stable, so ties keep their rule order.
    suggestions.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
    suggestions
}
/// Returns `true` when `question` contains at least one of `keywords`.
///
/// Matching is a plain substring search against the lowercased question.
/// Keywords are therefore expected to be lowercase already. A keyword also
/// matches inside longer words (e.g. `"then"` matches `"authentication"`).
fn question_matches_any(question: &str, keywords: &[&str]) -> bool {
    let haystack = question.to_lowercase();
    for &keyword in keywords {
        if haystack.contains(keyword) {
            return true;
        }
    }
    false
}
/// Extracts main topic from question for search
///
/// Heuristic: picks the longest whitespace-separated word of `question`.
/// Leading and trailing non-alphanumeric characters are trimmed first, so
/// `"guide?"` becomes `"guide"`. Length is measured in characters, not bytes.
/// When several words tie for longest, the last one wins.
///
/// Falls back to `"provisioning"` when no word survives trimming, i.e. for
/// empty, whitespace-only or punctuation-only input.
fn extract_main_topic(question: &str) -> String {
    question
        .split_whitespace()
        // Strip surrounding punctuation so it does not leak into the search query.
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|word| !word.is_empty())
        // Count chars rather than bytes so non-ASCII words are not favored.
        .max_by_key(|word| word.chars().count())
        .map(str::to_string)
        .unwrap_or_else(|| "provisioning".to_string())
}
/// Extracts error description from question
///
/// No parsing is done: the whole question, unmodified (including surrounding
/// whitespace), is returned as the error context.
fn extract_error_description(question: &str) -> String {
    String::from(question)
}
/// Enriches RAG answer with tool execution results
///
/// Appends tool execution results to the original answer.
pub fn enrich_answer_with_results(mut answer: String, tool_results: &[(String, Value)]) -> String {
if tool_results.is_empty() {
return answer;
}
answer.push_str("\n\n---\n\n**Tool Results:**\n\n");
for (tool_name, result) in tool_results {
answer.push_str(&format!("**{}:**\n", tool_name));
if let Some(status) = result.get("status") {
answer.push_str(&format!("Status: {}\n", status));
}
// Add tool-specific result formatting
if let Some(msg) = result.get("message") {
answer.push_str(&format!("{}\n", msg));
}
if let Some(suggestion) = result.get("suggestion") {
answer.push_str(&format!("{}\n", suggestion));
}
if let Some(diagnosis) = result.get("diagnosis") {
answer.push_str(&format!("{}\n", diagnosis));
}
answer.push('\n');
}
answer
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_status_detection() {
        let question = "What is the current system status?";
        let suggestions = analyze_for_tools("some answer", question);
        assert!(suggestions
            .iter()
            .any(|s| s.tool_name == "guidance_check_system_status"));
    }

    #[test]
    fn test_config_validation_detection() {
        let question = "Is my configuration valid?";
        let suggestions = analyze_for_tools("some answer", question);
        assert!(suggestions
            .iter()
            .any(|s| s.tool_name == "guidance_validate_config"));
    }

    #[test]
    fn test_doc_search_detection() {
        let question = "How do I use the provisioning guide?";
        let suggestions = analyze_for_tools("some answer", question);
        assert!(suggestions
            .iter()
            .any(|s| s.tool_name == "guidance_find_docs"));
    }

    #[test]
    fn test_no_keywords_yields_no_suggestions() {
        let suggestions = analyze_for_tools("some answer", "Hello");
        assert!(suggestions.is_empty());
    }

    #[test]
    fn test_suggestions_sorted_by_confidence() {
        let question = "What is the status? I got an error";
        let suggestions = analyze_for_tools("some answer", question);
        assert_eq!(suggestions.len(), 2);
        assert_eq!(suggestions[0].tool_name, "guidance_check_system_status");
        assert_eq!(suggestions[1].tool_name, "guidance_troubleshoot");
    }

    #[test]
    fn test_all_rules_sorted_and_ties_stable() {
        let question =
            "What's the status? How to fix the config error, and what should I search next?";
        let suggestions = analyze_for_tools("some answer", question);
        assert_eq!(suggestions.len(), 6);
        assert!(suggestions
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
        // Both rules have confidence 0.6; declaration order must be preserved.
        let position = |name: &str| {
            suggestions
                .iter()
                .position(|s| s.tool_name == name)
                .expect("rule should have matched")
        };
        assert!(position("guidance_validate_config") < position("rag_semantic_search"));
    }

    #[test]
    fn test_semantic_search_args() {
        let question = "Can you search for networking?";
        let suggestions = analyze_for_tools("some answer", question);
        assert_eq!(suggestions.len(), 1);
        let search = &suggestions[0];
        assert_eq!(search.tool_name, "rag_semantic_search");
        assert!(search.args["query"] == "Can you search for networking?");
        assert!(search.args["top_k"] == 5);
    }

    #[test]
    fn test_answer_enrichment() {
        let original = "RAG answer about provisioning".to_string();
        let results = vec![(
            "test_tool".to_string(),
            serde_json::json!({"status": "success", "message": "Tool ran"}),
        )];
        let enriched = enrich_answer_with_results(original, &results);
        assert!(enriched.starts_with("RAG answer about provisioning"));
        assert!(enriched.contains("Tool Results"));
        assert!(enriched.contains("test_tool"));
        assert!(enriched.contains("Tool ran"));
    }

    #[test]
    fn test_answer_enrichment_without_results_is_unchanged() {
        let original = "RAG answer about provisioning".to_string();
        let enriched = enrich_answer_with_results(original.clone(), &[]);
        assert_eq!(enriched, original);
    }
}