Vapora/crates/vapora-a2a/tests/integration_test.rs

362 lines
10 KiB
Rust
Raw Normal View History

use std::sync::Arc;
use std::time::Duration;
use surrealdb::{
engine::remote::ws::{Client, Ws},
opt::auth::Root,
Surreal,
};
use tokio::time::{sleep, timeout};
use vapora_a2a::{
bridge::CoordinatorBridge,
protocol::{A2aMessage, A2aMessagePart, A2aTask, TaskState},
task_manager::TaskManager,
};
use vapora_agents::{
config::AgentConfig,
coordinator::AgentCoordinator,
nats_bridge::{NatsBridge, NatsBrokerConfig, TaskResult},
registry::AgentRegistry,
};
/// Opens a WebSocket connection to the local SurrealDB test server, signs in
/// as root and selects the `test` namespace / `vapora_a2a_integration_test`
/// database.
///
/// # Panics
///
/// Panics if connecting, signing in, or selecting the namespace/database
/// fails. Every caller is an `#[ignore]`d test that expects a running server.
async fn setup_test_db() -> Surreal<Client> {
    let credentials = Root {
        username: "root",
        password: "root",
    };

    let db: Surreal<Client> = Surreal::new::<Ws>("127.0.0.1:8000")
        .await
        .expect("Failed to connect to SurrealDB");
    db.signin(credentials).await.expect("Failed to sign in");

    let namespace = "test";
    let database = "vapora_a2a_integration_test";
    db.use_ns(namespace)
        .use_db(database)
        .await
        .expect("Failed to use namespace");

    db
}
/// Connects a `NatsBridge` to the local NATS server.
///
/// The bridge uses a test-only stream (`VAPORA_TASKS_TEST`) and consumer name.
/// Every other broker option comes from `NatsBrokerConfig::default()`.
///
/// # Panics
///
/// Panics if the bridge cannot connect.
async fn setup_test_nats_bridge(registry: Arc<AgentRegistry>) -> Arc<NatsBridge> {
    let broker_config = NatsBrokerConfig {
        url: String::from("nats://127.0.0.1:4222"),
        stream_name: String::from("VAPORA_TASKS_TEST"),
        consumer_name: String::from("vapora-a2a-integration-test"),
        ..NatsBrokerConfig::default()
    };

    let bridge = NatsBridge::connect(broker_config, registry)
        .await
        .expect("Failed to connect NatsBridge");
    Arc::new(bridge)
}
/// Test 1: Task persistence — tasks survive TaskManager restart
///
/// Creates a task through one `TaskManager`. It then reads the task back
/// through a second instance built from the same database handle, which
/// simulates a restart. The freshly created task must still be in the
/// `waiting` state.
#[tokio::test]
#[ignore] // Requires SurrealDB running
async fn test_task_persistence_after_restart() {
    let db = setup_test_db().await;
    let task_id = "persistence-test-123";

    // The cleanup at the end only runs when every assertion passes. A run that
    // panicked leaves its row behind, so clear it before creating the task again.
    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id))
        .await;

    let task_manager = Arc::new(TaskManager::new(db.clone()));
    let task = A2aTask {
        id: task_id.to_string(),
        message: A2aMessage {
            role: "user".to_string(),
            parts: vec![A2aMessagePart::Text("Test persistence task".to_string())],
        },
        metadata: Default::default(),
    };
    task_manager
        .create(task)
        .await
        .expect("Failed to create task");

    // Simulate restart with a new TaskManager instance pointing to same DB
    let task_manager2 = Arc::new(TaskManager::new(db.clone()));
    let status = task_manager2
        .get(task_id)
        .await
        .expect("Task not found after restart");
    assert_eq!(status.id, task_id);
    assert_eq!(status.state, TaskState::Waiting.as_str());

    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id))
        .await;
}
/// Test 2: JetStream result updates DB — NatsBridge receives TaskResult and
/// persists completion to SurrealDB
#[tokio::test]
#[ignore] // Requires SurrealDB + NATS running
async fn test_jetstream_task_completion_updates_db() {
let db = setup_test_db().await;
let registry = Arc::new(AgentRegistry::new(10));
let nats_bridge = setup_test_nats_bridge(registry.clone()).await;
let task_manager = Arc::new(TaskManager::new(db.clone()));
let config = AgentConfig::default();
let coordinator = Arc::new(
AgentCoordinator::new(config, registry)
.await
.expect("Failed to create coordinator"),
);
let bridge = Arc::new(CoordinatorBridge::new(
coordinator,
task_manager.clone(),
Some(nats_bridge.clone()),
));
bridge
.start_result_listener()
.await
.expect("Failed to start result listener");
let task_id = "jetstream-completion-test-456".to_string();
let task = A2aTask {
id: task_id.clone(),
message: A2aMessage {
role: "user".to_string(),
parts: vec![A2aMessagePart::Text(
"Test JetStream completion".to_string(),
)],
},
metadata: Default::default(),
};
task_manager
.create(task)
.await
.expect("Failed to create task");
// Publish TaskResult to JetStream via a separate raw client — simulates
// agent completing a task and publishing to vapora.tasks.completed.
let raw_client = async_nats::connect("127.0.0.1:4222")
.await
.expect("Failed to connect raw NATS client");
let js = async_nats::jetstream::new(raw_client);
let result = TaskResult {
task_id: task_id.clone(),
agent_id: "test-agent".to_string(),
result: "Test output from agent".to_string(),
success: true,
duration_ms: 500,
};
js.publish(
"vapora.tasks.completed".to_string(),
serde_json::to_vec(&result).unwrap().into(),
)
.await
.expect("Failed to publish to JetStream")
.await
.expect("Failed to receive JetStream ack");
// Allow the pull consumer to fetch and process the message
sleep(Duration::from_millis(1500)).await;
let status = task_manager
.get(&task_id)
.await
.expect("Failed to get task status");
assert_eq!(status.state, TaskState::Completed.as_str());
assert!(status.result.is_some());
let result_msg = status.result.unwrap();
assert_eq!(result_msg.message.parts.len(), 1);
if let A2aMessagePart::Text(text) = &result_msg.message.parts[0] {
assert_eq!(text, "Test output from agent");
} else {
panic!("Expected text message part");
}
let _ = db
.query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
.bind(("task_id", task_id))
.await;
}
/// Test 3: Task state transitions work correctly (SurrealDB only)
///
/// Walks a task through `waiting` → `working` → `completed` via
/// `update_state` and `complete`, re-reading it after each step. The final
/// read must carry a result.
#[tokio::test]
#[ignore] // Requires SurrealDB running
async fn test_task_state_transitions() {
    let db = setup_test_db().await;
    let task_id = "state-transition-test-789".to_string();

    // The cleanup at the end only runs when every assertion passes. A stale row
    // from a panicked run could already be `completed`, so clear it first.
    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id.clone()))
        .await;

    let task_manager = Arc::new(TaskManager::new(db.clone()));
    let task = A2aTask {
        id: task_id.clone(),
        message: A2aMessage {
            role: "user".to_string(),
            parts: vec![A2aMessagePart::Text("Test state transitions".to_string())],
        },
        metadata: Default::default(),
    };
    task_manager
        .create(task)
        .await
        .expect("Failed to create task");

    let status = task_manager
        .get(&task_id)
        .await
        .expect("Failed to get task after create");
    assert_eq!(status.state, TaskState::Waiting.as_str());

    task_manager
        .update_state(&task_id, TaskState::Working)
        .await
        .expect("Failed to update to working");
    let status = task_manager
        .get(&task_id)
        .await
        .expect("Failed to get task after update_state");
    assert_eq!(status.state, TaskState::Working.as_str());

    let result = vapora_a2a::protocol::A2aTaskResult {
        message: A2aMessage {
            role: "assistant".to_string(),
            parts: vec![A2aMessagePart::Text("Task completed".to_string())],
        },
        artifacts: None,
    };
    task_manager
        .complete(&task_id, result)
        .await
        .expect("Failed to complete task");
    let status = task_manager
        .get(&task_id)
        .await
        .expect("Failed to get task after complete");
    assert_eq!(status.state, TaskState::Completed.as_str());
    assert!(status.result.is_some());

    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id))
        .await;
}
/// Test 4: Task failure handling (SurrealDB only)
///
/// Marks a freshly created task as failed with an `A2aErrorObj`. It then checks
/// that the stored state is `failed` and that the error message round-trips.
#[tokio::test]
#[ignore] // Requires SurrealDB running
async fn test_task_failure_handling() {
    let db = setup_test_db().await;
    let task_id = "failure-test-999".to_string();

    // The cleanup at the end only runs when every assertion passes. A run that
    // panicked leaves its row behind, so clear it before creating the task again.
    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id.clone()))
        .await;

    let task_manager = Arc::new(TaskManager::new(db.clone()));
    let task = A2aTask {
        id: task_id.clone(),
        message: A2aMessage {
            role: "user".to_string(),
            parts: vec![A2aMessagePart::Text("Test failure handling".to_string())],
        },
        metadata: Default::default(),
    };
    task_manager
        .create(task)
        .await
        .expect("Failed to create task");

    let error = vapora_a2a::protocol::A2aErrorObj {
        code: -1,
        message: "Test error message".to_string(),
    };
    task_manager
        .fail(&task_id, error)
        .await
        .expect("Failed to fail task");

    let status = task_manager
        .get(&task_id)
        .await
        .expect("Failed to get task after fail");
    assert_eq!(status.state, TaskState::Failed.as_str());
    let stored_error = status.error.expect("Failed task has no error recorded");
    assert_eq!(stored_error.message, "Test error message");

    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id))
        .await;
}
/// Test 5: End-to-end task dispatch with timeout
///
/// Dispatches a task through a `CoordinatorBridge` wired to NATS. It then polls
/// the task's status every 500 ms until the task is `completed` or `failed`, or
/// until 30 seconds elapse. The task must finish as completed.
#[tokio::test]
#[ignore] // Requires SurrealDB + NATS + Agent running
async fn test_end_to_end_task_dispatch() {
    let db = setup_test_db().await;
    let registry = Arc::new(AgentRegistry::new(10));
    let nats_bridge = setup_test_nats_bridge(registry.clone()).await;
    let task_manager = Arc::new(TaskManager::new(db.clone()));
    let config = AgentConfig::default();
    let coordinator = Arc::new(
        AgentCoordinator::new(config, registry)
            .await
            .expect("Failed to create coordinator"),
    );
    let bridge = Arc::new(CoordinatorBridge::new(
        coordinator,
        task_manager.clone(),
        Some(nats_bridge),
    ));
    bridge
        .start_result_listener()
        .await
        .expect("Failed to start result listener");

    let task = A2aTask {
        id: "e2e-test-task-001".to_string(),
        message: A2aMessage {
            role: "user".to_string(),
            parts: vec![A2aMessagePart::Text(
                "Create hello world function".to_string(),
            )],
        },
        metadata: Default::default(),
    };
    let task_id = bridge
        .dispatch(task)
        .await
        .expect("Failed to dispatch task");
    assert!(!task_id.is_empty());

    let final_status = timeout(Duration::from_secs(30), async {
        loop {
            let status = task_manager
                .get(&task_id)
                .await
                .expect("Failed to get task status");
            if status.state == TaskState::Completed.as_str()
                || status.state == TaskState::Failed.as_str()
            {
                return status;
            }
            sleep(Duration::from_millis(500)).await;
        }
    })
    .await
    .expect("Task did not complete within 30 seconds");

    // Include the recorded error so that a `failed` outcome is diagnosable from
    // the test output alone.
    assert_eq!(
        final_status.state,
        TaskState::Completed.as_str(),
        "task ended in state {:?} with error {:?}",
        final_status.state,
        final_status.error.as_ref().map(|e| &e.message)
    );

    let _ = db
        .query("DELETE FROM a2a_tasks WHERE task_id = $task_id")
        .bind(("task_id", task_id))
        .await;
}