374 lines
10 KiB
Rust
374 lines
10 KiB
Rust
|
|
use std::sync::Arc;
|
||
|
|
use std::time::Duration;
|
||
|
|
|
||
|
|
use surrealdb::{
|
||
|
|
engine::remote::ws::{Client, Ws},
|
||
|
|
opt::auth::Root,
|
||
|
|
Surreal,
|
||
|
|
};
|
||
|
|
use tokio::time::{sleep, timeout};
|
||
|
|
use vapora_a2a::{
|
||
|
|
bridge::CoordinatorBridge,
|
||
|
|
protocol::{A2aMessage, A2aMessagePart, A2aTask, TaskState},
|
||
|
|
task_manager::TaskManager,
|
||
|
|
};
|
||
|
|
use vapora_agents::{
|
||
|
|
config::AgentConfig, coordinator::AgentCoordinator, messages::AgentMessage,
|
||
|
|
messages::TaskCompleted, registry::AgentRegistry,
|
||
|
|
};
|
||
|
|
|
||
|
|
/// Setup test database connection
|
||
|
|
async fn setup_test_db() -> Surreal<Client> {
|
||
|
|
let db = Surreal::new::<Ws>("127.0.0.1:8000")
|
||
|
|
.await
|
||
|
|
.expect("Failed to connect to SurrealDB");
|
||
|
|
|
||
|
|
db.signin(Root {
|
||
|
|
username: "root",
|
||
|
|
password: "root",
|
||
|
|
})
|
||
|
|
.await
|
||
|
|
.expect("Failed to sign in");
|
||
|
|
|
||
|
|
db.use_ns("test")
|
||
|
|
.use_db("vapora_a2a_integration_test")
|
||
|
|
.await
|
||
|
|
.expect("Failed to use namespace");
|
||
|
|
|
||
|
|
db
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Setup test NATS connection
|
||
|
|
async fn setup_test_nats() -> async_nats::Client {
|
||
|
|
async_nats::connect("127.0.0.1:4222")
|
||
|
|
.await
|
||
|
|
.expect("Failed to connect to NATS")
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Test 1: Task persistence - tasks survive restarts
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB running
|
||
|
|
async fn test_task_persistence_after_restart() {
|
||
|
|
let db = setup_test_db().await;
|
||
|
|
let task_manager = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
|
||
|
|
let task = A2aTask {
|
||
|
|
id: "persistence-test-123".to_string(),
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "user".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text("Test persistence task".to_string())],
|
||
|
|
},
|
||
|
|
metadata: Default::default(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Create task
|
||
|
|
task_manager
|
||
|
|
.create(task)
|
||
|
|
.await
|
||
|
|
.expect("Failed to create task");
|
||
|
|
|
||
|
|
// Simulate restart by creating new TaskManager instance
|
||
|
|
let task_manager2 = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
|
||
|
|
// Verify task still exists
|
||
|
|
let status = task_manager2
|
||
|
|
.get("persistence-test-123")
|
||
|
|
.await
|
||
|
|
.expect("Failed to get status after restart");
|
||
|
|
|
||
|
|
assert_eq!(status.id, "persistence-test-123");
|
||
|
|
assert_eq!(status.state, TaskState::Waiting.as_str());
|
||
|
|
|
||
|
|
// Cleanup
|
||
|
|
let _ = db
|
||
|
|
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
|
||
|
|
.bind(("task_id", "persistence-test-123"))
|
||
|
|
.await;
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Test 2: NATS task completion updates DB correctly
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB + NATS running
|
||
|
|
async fn test_nats_task_completion_updates_db() {
|
||
|
|
let db = setup_test_db().await;
|
||
|
|
let nats = setup_test_nats().await;
|
||
|
|
|
||
|
|
let task_manager = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
let registry = Arc::new(AgentRegistry::new(10));
|
||
|
|
let config = AgentConfig::default();
|
||
|
|
let coordinator = Arc::new(AgentCoordinator::new(config, registry).await.unwrap());
|
||
|
|
|
||
|
|
let bridge = Arc::new(CoordinatorBridge::new(
|
||
|
|
coordinator,
|
||
|
|
task_manager.clone(),
|
||
|
|
Some(nats.clone()),
|
||
|
|
));
|
||
|
|
|
||
|
|
bridge
|
||
|
|
.start_result_listener()
|
||
|
|
.await
|
||
|
|
.expect("Failed to start listener");
|
||
|
|
|
||
|
|
let task_id = "nats-completion-test-456".to_string();
|
||
|
|
|
||
|
|
// Create task
|
||
|
|
let task = A2aTask {
|
||
|
|
id: task_id.clone(),
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "user".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text("Test NATS completion".to_string())],
|
||
|
|
},
|
||
|
|
metadata: Default::default(),
|
||
|
|
};
|
||
|
|
|
||
|
|
task_manager
|
||
|
|
.create(task)
|
||
|
|
.await
|
||
|
|
.expect("Failed to create task");
|
||
|
|
|
||
|
|
// Publish TaskCompleted message to NATS
|
||
|
|
let task_completed = TaskCompleted {
|
||
|
|
task_id: task_id.clone(),
|
||
|
|
agent_id: "test-agent".to_string(),
|
||
|
|
result: "Test output from agent".to_string(),
|
||
|
|
artifacts: vec!["/path/to/artifact.txt".to_string()],
|
||
|
|
tokens_used: 100,
|
||
|
|
duration_ms: 500,
|
||
|
|
completed_at: chrono::Utc::now(),
|
||
|
|
};
|
||
|
|
|
||
|
|
let message = AgentMessage::TaskCompleted(task_completed);
|
||
|
|
nats.publish(
|
||
|
|
"vapora.tasks.completed",
|
||
|
|
serde_json::to_vec(&message).unwrap().into(),
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
.expect("Failed to publish");
|
||
|
|
|
||
|
|
// Wait for DB update (give NATS subscriber time to process)
|
||
|
|
sleep(Duration::from_millis(1000)).await;
|
||
|
|
|
||
|
|
// Verify DB updated
|
||
|
|
let status = task_manager
|
||
|
|
.get(&task_id)
|
||
|
|
.await
|
||
|
|
.expect("Failed to get status");
|
||
|
|
|
||
|
|
assert_eq!(status.state, TaskState::Completed.as_str());
|
||
|
|
assert!(status.result.is_some());
|
||
|
|
|
||
|
|
let result = status.result.unwrap();
|
||
|
|
assert_eq!(result.message.parts.len(), 1);
|
||
|
|
|
||
|
|
if let A2aMessagePart::Text(text) = &result.message.parts[0] {
|
||
|
|
assert_eq!(text, "Test output from agent");
|
||
|
|
} else {
|
||
|
|
panic!("Expected text message part");
|
||
|
|
}
|
||
|
|
|
||
|
|
assert!(result.artifacts.is_some());
|
||
|
|
assert_eq!(result.artifacts.as_ref().unwrap().len(), 1);
|
||
|
|
|
||
|
|
// Cleanup
|
||
|
|
let _ = db
|
||
|
|
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
|
||
|
|
.bind(("task_id", task_id))
|
||
|
|
.await;
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Test 3: Task state transitions work correctly
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB running
|
||
|
|
async fn test_task_state_transitions() {
|
||
|
|
let db = setup_test_db().await;
|
||
|
|
let task_manager = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
|
||
|
|
let task_id = "state-transition-test-789".to_string();
|
||
|
|
|
||
|
|
let task = A2aTask {
|
||
|
|
id: task_id.clone(),
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "user".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text("Test state transitions".to_string())],
|
||
|
|
},
|
||
|
|
metadata: Default::default(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Create task (waiting state)
|
||
|
|
task_manager
|
||
|
|
.create(task)
|
||
|
|
.await
|
||
|
|
.expect("Failed to create task");
|
||
|
|
|
||
|
|
let status = task_manager.get(&task_id).await.unwrap();
|
||
|
|
assert_eq!(status.state, TaskState::Waiting.as_str());
|
||
|
|
|
||
|
|
// Transition to working
|
||
|
|
task_manager
|
||
|
|
.update_state(&task_id, TaskState::Working)
|
||
|
|
.await
|
||
|
|
.expect("Failed to update to working");
|
||
|
|
|
||
|
|
let status = task_manager.get(&task_id).await.unwrap();
|
||
|
|
assert_eq!(status.state, TaskState::Working.as_str());
|
||
|
|
|
||
|
|
// Complete task
|
||
|
|
let result = vapora_a2a::protocol::A2aTaskResult {
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "assistant".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text("Task completed".to_string())],
|
||
|
|
},
|
||
|
|
artifacts: None,
|
||
|
|
};
|
||
|
|
|
||
|
|
task_manager
|
||
|
|
.complete(&task_id, result)
|
||
|
|
.await
|
||
|
|
.expect("Failed to complete task");
|
||
|
|
|
||
|
|
let status = task_manager.get(&task_id).await.unwrap();
|
||
|
|
assert_eq!(status.state, TaskState::Completed.as_str());
|
||
|
|
assert!(status.result.is_some());
|
||
|
|
|
||
|
|
// Cleanup
|
||
|
|
let _ = db
|
||
|
|
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
|
||
|
|
.bind(("task_id", task_id))
|
||
|
|
.await;
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Test 4: Task failure handling
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB running
|
||
|
|
async fn test_task_failure_handling() {
|
||
|
|
let db = setup_test_db().await;
|
||
|
|
let task_manager = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
|
||
|
|
let task_id = "failure-test-999".to_string();
|
||
|
|
|
||
|
|
let task = A2aTask {
|
||
|
|
id: task_id.clone(),
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "user".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text("Test failure handling".to_string())],
|
||
|
|
},
|
||
|
|
metadata: Default::default(),
|
||
|
|
};
|
||
|
|
|
||
|
|
task_manager
|
||
|
|
.create(task)
|
||
|
|
.await
|
||
|
|
.expect("Failed to create task");
|
||
|
|
|
||
|
|
// Fail task
|
||
|
|
let error = vapora_a2a::protocol::A2aErrorObj {
|
||
|
|
code: -1,
|
||
|
|
message: "Test error message".to_string(),
|
||
|
|
};
|
||
|
|
|
||
|
|
task_manager
|
||
|
|
.fail(&task_id, error)
|
||
|
|
.await
|
||
|
|
.expect("Failed to fail task");
|
||
|
|
|
||
|
|
let status = task_manager.get(&task_id).await.unwrap();
|
||
|
|
assert_eq!(status.state, TaskState::Failed.as_str());
|
||
|
|
assert!(status.error.is_some());
|
||
|
|
assert_eq!(status.error.unwrap().message, "Test error message");
|
||
|
|
|
||
|
|
// Cleanup
|
||
|
|
let _ = db
|
||
|
|
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
|
||
|
|
.bind(("task_id", task_id))
|
||
|
|
.await;
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Test 5: End-to-end task dispatch with timeout
|
||
|
|
#[tokio::test]
|
||
|
|
#[ignore] // Requires SurrealDB + NATS + Agent running
|
||
|
|
async fn test_end_to_end_task_dispatch() {
|
||
|
|
let db = setup_test_db().await;
|
||
|
|
let nats = setup_test_nats().await;
|
||
|
|
|
||
|
|
let task_manager = Arc::new(TaskManager::new(db.clone()));
|
||
|
|
let registry = Arc::new(AgentRegistry::new(10));
|
||
|
|
let config = AgentConfig::default();
|
||
|
|
let coordinator = Arc::new(AgentCoordinator::new(config, registry).await.unwrap());
|
||
|
|
|
||
|
|
let bridge = Arc::new(CoordinatorBridge::new(
|
||
|
|
coordinator,
|
||
|
|
task_manager.clone(),
|
||
|
|
Some(nats.clone()),
|
||
|
|
));
|
||
|
|
|
||
|
|
bridge
|
||
|
|
.start_result_listener()
|
||
|
|
.await
|
||
|
|
.expect("Failed to start listener");
|
||
|
|
|
||
|
|
let task = A2aTask {
|
||
|
|
id: "e2e-test-task-001".to_string(),
|
||
|
|
message: A2aMessage {
|
||
|
|
role: "user".to_string(),
|
||
|
|
parts: vec![A2aMessagePart::Text(
|
||
|
|
"Create hello world function".to_string(),
|
||
|
|
)],
|
||
|
|
},
|
||
|
|
metadata: Default::default(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Dispatch task
|
||
|
|
let task_id = bridge
|
||
|
|
.dispatch(task)
|
||
|
|
.await
|
||
|
|
.expect("Failed to dispatch task");
|
||
|
|
|
||
|
|
// Poll for completion with timeout
|
||
|
|
let result = timeout(Duration::from_secs(60), async {
|
||
|
|
loop {
|
||
|
|
let status = bridge
|
||
|
|
.get_task(&task_id)
|
||
|
|
.await
|
||
|
|
.expect("Failed to get status");
|
||
|
|
|
||
|
|
match task_state_from_str(&status.state) {
|
||
|
|
TaskState::Completed => return Ok(status),
|
||
|
|
TaskState::Failed => return Err(format!("Task failed: {:?}", status.error)),
|
||
|
|
_ => {
|
||
|
|
sleep(Duration::from_millis(500)).await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
})
|
||
|
|
.await;
|
||
|
|
|
||
|
|
match result {
|
||
|
|
Ok(Ok(status)) => {
|
||
|
|
println!("Task completed successfully: {:?}", status);
|
||
|
|
assert_eq!(status.state, TaskState::Completed.as_str());
|
||
|
|
}
|
||
|
|
Ok(Err(e)) => panic!("Task failed: {}", e),
|
||
|
|
Err(_) => {
|
||
|
|
println!(
|
||
|
|
"Task did not complete within 60 seconds (this is expected if no agent is running)"
|
||
|
|
);
|
||
|
|
// Cleanup partial task
|
||
|
|
let _ = db
|
||
|
|
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
|
||
|
|
.bind(("task_id", task_id))
|
||
|
|
.await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Helper to convert string to TaskState
|
||
|
|
fn task_state_from_str(s: &str) -> TaskState {
|
||
|
|
match s {
|
||
|
|
"waiting" => TaskState::Waiting,
|
||
|
|
"working" => TaskState::Working,
|
||
|
|
"completed" => TaskState::Completed,
|
||
|
|
"failed" => TaskState::Failed,
|
||
|
|
_ => TaskState::Waiting,
|
||
|
|
}
|
||
|
|
}
|