//! Example: RAG Agent with Function Calling / Tool Use
//!
//! Demonstrates how to use the `ToolRegistry` to:
//! - Register tools for Claude to invoke
//! - Execute tools with security validation
//! - Track tool usage statistics
//!
//! It also outlines how tool results are integrated into RAG responses. The
//! agent-side integration (e.g. `ToolEnabledRagAgent`) is only described in the
//! printed output; it is not exercised by this example.
|
||
|
|
|
||
|
|
#![allow(unused_imports)]
|
||
|
|
|
||
|
|
use std::sync::Arc;
|
||
|
|
|
||
|
|
use provisioning_rag::{
|
||
|
|
CreateServerTool, RagTool, TaskservManagementTool, ToolCall, ToolDefinition, ToolInput,
|
||
|
|
ToolOutput, ToolRegistry, WorkspaceStatusTool,
|
||
|
|
};
|
||
|
|
|
||
|
|
#[tokio::main]
|
||
|
|
async fn main() -> anyhow::Result<()> {
|
||
|
|
// Initialize logging
|
||
|
|
tracing_subscriber::fmt()
|
||
|
|
.with_max_level(tracing::Level::INFO)
|
||
|
|
.init();
|
||
|
|
|
||
|
|
println!("=== RAG Agent with Function Calling ===\n");
|
||
|
|
|
||
|
|
// 1. Create tool registry
|
||
|
|
println!("1. Setting up tool registry...");
|
||
|
|
let registry = Arc::new(ToolRegistry::new());
|
||
|
|
println!(" ✓ Tool registry created\n");
|
||
|
|
|
||
|
|
// 2. Register core tools
|
||
|
|
println!("2. Registering tools...");
|
||
|
|
registry.register(Arc::new(CreateServerTool::new())).await?;
|
||
|
|
registry
|
||
|
|
.register(Arc::new(WorkspaceStatusTool::new("production".to_string())))
|
||
|
|
.await?;
|
||
|
|
registry
|
||
|
|
.register(Arc::new(TaskservManagementTool::new()))
|
||
|
|
.await?;
|
||
|
|
println!(" ✓ 3 tools registered\n");
|
||
|
|
|
||
|
|
// 3. Display available tools
|
||
|
|
println!("3. Available tools:");
|
||
|
|
let definitions = registry.get_definitions().await;
|
||
|
|
for def in &definitions {
|
||
|
|
println!(" - {} ({})", def.name, def.description);
|
||
|
|
}
|
||
|
|
println!();
|
||
|
|
|
||
|
|
// 4. Execute tools with validation
|
||
|
|
println!("=== Tool Execution Examples ===\n");
|
||
|
|
|
||
|
|
// Example 1: Workspace Status (no auth required)
|
||
|
|
println!("Example 1: Get workspace status");
|
||
|
|
let input = ToolInput {
|
||
|
|
params: serde_json::json!({"include_metrics": true}),
|
||
|
|
};
|
||
|
|
let output = registry
|
||
|
|
.call_tool("get_workspace_status", input, "user@example.com")
|
||
|
|
.await?;
|
||
|
|
println!("Result: {}", output.result);
|
||
|
|
println!("Success: {}\n", output.success);
|
||
|
|
|
||
|
|
// Example 2: Create Server (requires admin)
|
||
|
|
println!("Example 2: Create server (non-admin user)");
|
||
|
|
let input = ToolInput {
|
||
|
|
params: serde_json::json!({
|
||
|
|
"hostname": "web-01",
|
||
|
|
"cores": 4,
|
||
|
|
"memory_gb": 8,
|
||
|
|
"region": "us-east-1"
|
||
|
|
}),
|
||
|
|
};
|
||
|
|
let output = registry
|
||
|
|
.call_tool("create_server", input, "user@example.com")
|
||
|
|
.await?;
|
||
|
|
println!("Result: {}", output.result);
|
||
|
|
if let Some(error) = &output.error {
|
||
|
|
println!("Error (expected): {}", error);
|
||
|
|
}
|
||
|
|
println!();
|
||
|
|
|
||
|
|
// Example 3: Create Server (admin user)
|
||
|
|
println!("Example 3: Create server (admin user)");
|
||
|
|
let input = ToolInput {
|
||
|
|
params: serde_json::json!({
|
||
|
|
"hostname": "web-02",
|
||
|
|
"cores": 8,
|
||
|
|
"memory_gb": 16,
|
||
|
|
"region": "eu-west-1"
|
||
|
|
}),
|
||
|
|
};
|
||
|
|
let output = registry
|
||
|
|
.call_tool("create_server", input, "admin@example.com")
|
||
|
|
.await?;
|
||
|
|
println!("Result: {}", output.result);
|
||
|
|
println!("Success: {}\n", output.success);
|
||
|
|
|
||
|
|
// Example 4: Manage Taskserv
|
||
|
|
println!("Example 4: Manage taskserv");
|
||
|
|
let input = ToolInput {
|
||
|
|
params: serde_json::json!({
|
||
|
|
"operation": "create",
|
||
|
|
"service": "kubernetes",
|
||
|
|
"version": "1.28.0"
|
||
|
|
}),
|
||
|
|
};
|
||
|
|
let output = registry
|
||
|
|
.call_tool("manage_taskserv", input, "user@example.com")
|
||
|
|
.await?;
|
||
|
|
println!("Result: {}", output.result);
|
||
|
|
println!("Success: {}\n", output.success);
|
||
|
|
|
||
|
|
// 5. Display tool statistics
|
||
|
|
println!("=== Tool Call Statistics ===\n");
|
||
|
|
let stats = registry.get_stats().await;
|
||
|
|
println!("Total calls: {}", stats.total_calls);
|
||
|
|
println!("Successful calls: {}", stats.successful_calls);
|
||
|
|
println!("Failed calls: {}", stats.failed_calls);
|
||
|
|
println!("Blocked calls (auth/rate limit): {}", stats.blocked_calls);
|
||
|
|
println!();
|
||
|
|
|
||
|
|
println!("Calls by tool:");
|
||
|
|
for (tool, count) in stats.calls_by_tool {
|
||
|
|
println!(" - {}: {} calls", tool, count);
|
||
|
|
}
|
||
|
|
println!();
|
||
|
|
|
||
|
|
// 6. Demonstrate security validation
|
||
|
|
println!("=== Security Validation ===\n");
|
||
|
|
println!("Tool registration ensures:");
|
||
|
|
println!("✓ Access control (admin vs user)");
|
||
|
|
println!("✓ Rate limiting (prevents abuse)");
|
||
|
|
println!("✓ Audit logging (tracks all invocations)");
|
||
|
|
println!("✓ Input validation (rejects invalid params)");
|
||
|
|
println!();
|
||
|
|
|
||
|
|
// 7. Tool integration workflow
|
||
|
|
println!("=== Tool Integration Workflow ===\n");
|
||
|
|
println!("1. Claude generates response with tool_use tags");
|
||
|
|
println!("2. Agent parses tool calls from Claude response");
|
||
|
|
println!("3. Registry validates user access and rate limits");
|
||
|
|
println!("4. Tool executes with validated parameters");
|
||
|
|
println!("5. Results logged to audit trail");
|
||
|
|
println!("6. Results integrated into final response");
|
||
|
|
println!("7. Statistics updated for monitoring\n");
|
||
|
|
|
||
|
|
// 8. Production integration
|
||
|
|
println!("=== Production Integration ===\n");
|
||
|
|
println!("In production RAG system:");
|
||
|
|
println!("- Tools are registered at agent initialization");
|
||
|
|
println!("- Claude response is parsed for tool_use tags");
|
||
|
|
println!("- Tools execute in controlled environment");
|
||
|
|
println!("- Results enhance RAG responses with actions");
|
||
|
|
println!("- Audit logs enable compliance and debugging");
|
||
|
|
println!("- Statistics drive monitoring and optimization\n");
|
||
|
|
|
||
|
|
println!("=== Tool System Benefits ===\n");
|
||
|
|
println!("✓ Extensible: Add new tools by implementing RagTool trait");
|
||
|
|
println!("✓ Secure: Access control, rate limiting, input validation");
|
||
|
|
println!("✓ Observable: Complete audit logging and statistics");
|
||
|
|
println!("✓ Flexible: Custom tools for any system operation");
|
||
|
|
println!("✓ Production-ready: Error handling, async support, testing\n");
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|