806 lines
26 KiB
Rust
Raw Permalink Normal View History

2025-10-07 10:59:52 +01:00
//! Test utilities and helpers for multi-storage orchestrator testing
//!
//! This module provides shared test utilities, mock implementations,
//! test data generators, and helper functions for testing all storage backends.
use async_trait::async_trait;
use chrono::{Duration, Utc};
use futures::stream::{self, BoxStream};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tempfile::TempDir;
use uuid::Uuid;
// Re-export commonly used types
pub use tempfile;
pub use tokio_test;
// Import orchestrator types
use orchestrator::{TaskStatus, WorkflowTask};
use orchestrator::storage::{
create_storage, available_storage_types, StorageConfig, TaskStorage, StorageResult,
StorageError, StorageStatistics, TimeRange, AuditEntry, Metric, AuthToken,
TaskEvent, TaskEventType
};
/// Test data generator for creating consistent test objects.
///
/// The generator is cheap to clone: every clone holds a handle to the same
/// counter, so task IDs stay unique across all clones handed to test
/// closures.
#[derive(Debug, Clone)]
pub struct TestDataGenerator {
    /// Counter used to derive unique task IDs; shared between clones via `Arc`.
    counter: Arc<Mutex<u64>>,
}
impl TestDataGenerator {
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
}
}
/// Generate a unique task ID
pub fn task_id(&self) -> String {
let mut counter = self.counter.lock().unwrap();
*counter += 1;
format!("test_task_{:04}", counter)
}
/// Generate a workflow task with default values
pub fn workflow_task(&self) -> WorkflowTask {
let id = self.task_id();
WorkflowTask {
id: id.clone(),
name: format!("Test Task {}", id),
command: "echo".to_string(),
args: vec!["hello".to_string()],
dependencies: vec![],
status: TaskStatus::Pending,
created_at: Utc::now(),
started_at: None,
completed_at: None,
output: None,
error: None,
}
}
/// Generate a workflow task with custom status
pub fn workflow_task_with_status(&self, status: TaskStatus) -> WorkflowTask {
let mut task = self.workflow_task();
task.status = status;
match status {
TaskStatus::Running => {
task.started_at = Some(Utc::now());
}
TaskStatus::Completed => {
task.started_at = Some(Utc::now() - Duration::minutes(5));
task.completed_at = Some(Utc::now());
task.output = Some("Test output".to_string());
}
TaskStatus::Failed => {
task.started_at = Some(Utc::now() - Duration::minutes(3));
task.completed_at = Some(Utc::now());
task.error = Some("Test error".to_string());
}
_ => {}
}
task
}
/// Generate a workflow task with dependencies
pub fn workflow_task_with_deps(&self, deps: Vec<String>) -> WorkflowTask {
let mut task = self.workflow_task();
task.dependencies = deps;
task
}
/// Generate multiple tasks with different statuses
pub fn workflow_tasks_batch(&self, count: usize) -> Vec<WorkflowTask> {
let statuses = [
TaskStatus::Pending,
TaskStatus::Running,
TaskStatus::Completed,
TaskStatus::Failed,
];
(0..count)
.map(|i| {
let status = statuses[i % statuses.len()].clone();
self.workflow_task_with_status(status)
})
.collect()
}
/// Generate an audit entry
pub fn audit_entry(&self, task_id: String) -> AuditEntry {
AuditEntry {
id: Uuid::new_v4().to_string(),
task_id,
operation: "test_operation".to_string(),
old_status: Some(TaskStatus::Pending),
new_status: Some(TaskStatus::Running),
user_id: Some("test_user".to_string()),
timestamp: Utc::now(),
metadata: {
let mut meta = HashMap::new();
meta.insert("test_key".to_string(), "test_value".to_string());
meta
},
}
}
/// Generate a metric
pub fn metric(&self, name: &str, value: f64) -> Metric {
Metric {
name: name.to_string(),
value,
tags: {
let mut tags = HashMap::new();
tags.insert("environment".to_string(), "test".to_string());
tags
},
timestamp: Utc::now(),
}
}
/// Generate an auth token
pub fn auth_token(&self, user_id: &str) -> AuthToken {
AuthToken {
token: Uuid::new_v4().to_string(),
user_id: user_id.to_string(),
expires_at: Utc::now() + Duration::hours(24),
permissions: vec!["read".to_string(), "write".to_string()],
}
}
/// Generate a task event
pub fn task_event(&self, task_id: String, event_type: TaskEventType) -> TaskEvent {
TaskEvent {
event_type,
task_id,
task: Some(self.workflow_task()),
timestamp: Utc::now(),
metadata: HashMap::new(),
}
}
}
impl Default for TestDataGenerator {
fn default() -> Self {
Self::new()
}
}
/// Test environment setup and cleanup utilities.
///
/// Bundles the temporary directories created for a test run with the data
/// generator used to populate storage backends.
pub struct TestEnvironment {
/// Handles for every temporary directory created through `create_temp_dir`.
/// NOTE(review): `tempfile::TempDir` is documented to remove its directory
/// on drop, so holding the handles here presumably keeps the directories
/// alive for the lifetime of the environment - confirm against tempfile docs.
pub temp_dirs: Vec<TempDir>,
/// Generator used to create test tasks, audit entries, metrics, tokens and events.
pub generator: TestDataGenerator,
}
impl TestEnvironment {
    /// Creates an environment with no temporary directories and a fresh
    /// data generator.
    pub fn new() -> Self {
        Self {
            temp_dirs: Vec::new(),
            generator: TestDataGenerator::new(),
        }
    }

    /// Create a temporary directory for testing and return its path.
    ///
    /// The `TempDir` handle is stored in `temp_dirs`, so it is owned by the
    /// environment rather than dropped at the end of this call.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from `TempDir::new` if the directory cannot be
    /// created.
    pub fn create_temp_dir(&mut self) -> std::io::Result<PathBuf> {
        let temp_dir = TempDir::new()?;
        let path = temp_dir.path().to_path_buf();
        self.temp_dirs.push(temp_dir);
        Ok(path)
    }

    /// Create storage configuration for the filesystem backend, backed by a
    /// new temporary directory.
    ///
    /// # Errors
    ///
    /// Fails if the temporary data directory cannot be created.
    pub fn filesystem_config(&mut self) -> std::io::Result<StorageConfig> {
        let data_dir = self.create_temp_dir()?;
        Ok(StorageConfig {
            storage_type: "filesystem".to_string(),
            data_dir: data_dir.to_string_lossy().into_owned(),
            ..Default::default()
        })
    }

    /// Create storage configuration for the SurrealDB embedded backend,
    /// backed by a new temporary directory.
    ///
    /// # Errors
    ///
    /// Fails if the temporary data directory cannot be created.
    #[cfg(feature = "surrealdb")]
    pub fn surrealdb_embedded_config(&mut self) -> std::io::Result<StorageConfig> {
        let data_dir = self.create_temp_dir()?;
        Ok(StorageConfig {
            storage_type: "surrealdb-embedded".to_string(),
            data_dir: data_dir.to_string_lossy().into_owned(),
            surrealdb_namespace: Some("test_orchestrator".to_string()),
            surrealdb_database: Some("test_tasks".to_string()),
            ..Default::default()
        })
    }

    /// Create storage configuration for the SurrealDB server backend.
    ///
    /// Uses the `memory://test` URL with `test`/`test` credentials and an
    /// empty data directory.
    #[cfg(feature = "surrealdb")]
    pub fn surrealdb_server_config(&mut self) -> StorageConfig {
        StorageConfig {
            storage_type: "surrealdb-server".to_string(),
            data_dir: "".to_string(),
            surrealdb_url: Some("memory://test".to_string()),
            surrealdb_namespace: Some("test_orchestrator".to_string()),
            surrealdb_database: Some("test_tasks".to_string()),
            surrealdb_username: Some("test".to_string()),
            surrealdb_password: Some("test".to_string()),
        }
    }

    /// Get all available storage configurations for testing.
    ///
    /// The filesystem backend is always attempted; the SurrealDB backends
    /// are added when the `surrealdb` feature is enabled. A backend whose
    /// configuration cannot be created is left out of the result and the
    /// reason is written to stderr, so a skipped backend is visible in the
    /// test output instead of disappearing silently.
    pub fn all_storage_configs(&mut self) -> Vec<StorageConfig> {
        let mut configs = Vec::new();

        // Always include filesystem
        match self.filesystem_config() {
            Ok(fs_config) => configs.push(fs_config),
            Err(e) => eprintln!("Skipping filesystem backend: failed to create config: {}", e),
        }

        // Include SurrealDB backends if feature is enabled
        #[cfg(feature = "surrealdb")]
        {
            match self.surrealdb_embedded_config() {
                Ok(embedded_config) => configs.push(embedded_config),
                Err(e) => eprintln!(
                    "Skipping surrealdb-embedded backend: failed to create config: {}",
                    e
                ),
            }
            // Note: Server config using memory:// for testing
            configs.push(self.surrealdb_server_config());
        }

        configs
    }
}
impl Default for TestEnvironment {
fn default() -> Self {
Self::new()
}
}
/// Storage test runner for executing generic tests across all backends.
///
/// Owns the `TestEnvironment` from which backend configurations and the
/// shared data generator are obtained.
pub struct StorageTestRunner {
/// Environment supplying storage configs and the data generator passed to tests.
env: TestEnvironment,
}
impl StorageTestRunner {
pub fn new() -> Self {
Self {
env: TestEnvironment::new(),
}
}
/// Run a test against all available storage backends
pub async fn run_against_all_backends<F, Fut>(&mut self, test_fn: F)
where
F: Fn(Box<dyn TaskStorage>, TestDataGenerator) -> Fut + Clone,
Fut: std::future::Future<Output = StorageResult<()>>,
{
let configs = self.env.all_storage_configs();
for config in configs {
println!("Testing with {} backend", config.storage_type);
let storage = create_storage(config.clone()).await
.expect(&format!("Failed to create {} storage", config.storage_type));
storage.init().await
.expect(&format!("Failed to initialize {} storage", config.storage_type));
let result = test_fn(storage, self.env.generator.clone()).await;
if let Err(e) = result {
panic!("Test failed for {} backend: {}", config.storage_type, e);
}
println!("{} backend passed", config.storage_type);
}
}
/// Run a test against specific storage backend
pub async fn run_against_backend<F, Fut>(&mut self, backend_type: &str, test_fn: F)
where
F: Fn(Box<dyn TaskStorage>, TestDataGenerator) -> Fut,
Fut: std::future::Future<Output = StorageResult<()>>,
{
let config = match backend_type {
"filesystem" => self.env.filesystem_config().expect("Failed to create filesystem config"),
#[cfg(feature = "surrealdb")]
"surrealdb-embedded" => self.env.surrealdb_embedded_config().expect("Failed to create SurrealDB embedded config"),
#[cfg(feature = "surrealdb")]
"surrealdb-server" => self.env.surrealdb_server_config(),
_ => panic!("Unsupported backend type: {}", backend_type),
};
let storage = create_storage(config).await
.expect(&format!("Failed to create {} storage", backend_type));
storage.init().await
.expect(&format!("Failed to initialize {} storage", backend_type));
let result = test_fn(storage, self.env.generator.clone()).await;
if let Err(e) = result {
panic!("Test failed for {} backend: {}", backend_type, e);
}
}
}
impl Default for StorageTestRunner {
fn default() -> Self {
Self::new()
}
}
/// Mock storage implementation for testing.
///
/// All state lives in memory behind `Arc<Mutex<..>>` fields, which are
/// public so tests can inspect or seed them directly.
pub struct MockStorage {
/// Every task known to the mock, keyed by task ID.
pub tasks: Arc<Mutex<HashMap<String, WorkflowTask>>>,
/// Queued tasks paired with the priority they were enqueued with.
pub queue: Arc<Mutex<Vec<(WorkflowTask, u8)>>>,
/// Audit entries in the order they were recorded.
pub audit_log: Arc<Mutex<Vec<AuditEntry>>>,
/// Metrics in the order they were recorded.
pub metrics: Arc<Mutex<Vec<Metric>>>,
/// Issued auth tokens, keyed by token string.
pub auth_tokens: Arc<Mutex<HashMap<String, AuthToken>>>,
/// Published events in publication order.
pub events: Arc<Mutex<Vec<TaskEvent>>>,
/// Value reported by `health_check`; starts as `true`.
pub health: Arc<Mutex<bool>>,
/// Number of calls made to each trait method, keyed by method name.
pub call_count: Arc<Mutex<HashMap<String, usize>>>,
}
impl MockStorage {
pub fn new() -> Self {
Self {
tasks: Arc::new(Mutex::new(HashMap::new())),
queue: Arc::new(Mutex::new(Vec::new())),
audit_log: Arc::new(Mutex::new(Vec::new())),
metrics: Arc::new(Mutex::new(Vec::new())),
auth_tokens: Arc::new(Mutex::new(HashMap::new())),
events: Arc::new(Mutex::new(Vec::new())),
health: Arc::new(Mutex::new(true)),
call_count: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Set health status for testing
pub fn set_health(&self, healthy: bool) {
*self.health.lock().unwrap() = healthy;
}
/// Get call count for a method
pub fn get_call_count(&self, method: &str) -> usize {
self.call_count.lock().unwrap().get(method).copied().unwrap_or(0)
}
/// Increment call count for a method
fn increment_call_count(&self, method: &str) {
let mut counts = self.call_count.lock().unwrap();
*counts.entry(method.to_string()).or_insert(0) += 1;
}
/// Clear all data (for test isolation)
pub fn clear(&self) {
self.tasks.lock().unwrap().clear();
self.queue.lock().unwrap().clear();
self.audit_log.lock().unwrap().clear();
self.metrics.lock().unwrap().clear();
self.auth_tokens.lock().unwrap().clear();
self.events.lock().unwrap().clear();
self.call_count.lock().unwrap().clear();
}
}
impl Default for MockStorage {
fn default() -> Self {
Self::new()
}
}
/// In-memory `TaskStorage` implementation.
///
/// Every method records its invocation in `call_count` before doing any
/// work. When both the task map and the queue are needed, they are always
/// locked in the order `tasks` then `queue` so that concurrent callers
/// cannot deadlock on each other.
///
/// All methods panic if one of the internal mutexes is poisoned.
#[async_trait]
impl TaskStorage for MockStorage {
    /// No-op initialisation; always succeeds.
    async fn init(&self) -> StorageResult<()> {
        self.increment_call_count("init");
        Ok(())
    }

    /// Reports the value last set through `set_health` (initially `true`).
    async fn health_check(&self) -> StorageResult<bool> {
        self.increment_call_count("health_check");
        Ok(*self.health.lock().unwrap())
    }

    /// Stores `task` and adds it to the queue with `priority`.
    ///
    /// The queue is kept sorted by descending priority; the sort is stable,
    /// so tasks of equal priority stay in insertion order.
    async fn enqueue(&self, task: WorkflowTask, priority: u8) -> StorageResult<()> {
        self.increment_call_count("enqueue");
        // Lock order: tasks, then queue (same as `requeue_failed_task`).
        let mut tasks = self.tasks.lock().unwrap();
        let mut queue = self.queue.lock().unwrap();
        tasks.insert(task.id.clone(), task.clone());
        queue.push((task, priority));
        queue.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by priority (descending)
        Ok(())
    }

    /// Removes and returns the highest-priority queued task, oldest first
    /// among equal priorities, or `None` when the queue is empty.
    ///
    /// The task remains in the task map.
    async fn dequeue(&self) -> StorageResult<Option<WorkflowTask>> {
        self.increment_call_count("dequeue");
        let mut queue = self.queue.lock().unwrap();
        if queue.is_empty() {
            return Ok(None);
        }
        // The queue is sorted highest-priority first, so the head is next.
        let (task, _) = queue.remove(0);
        Ok(Some(task))
    }

    /// Returns a copy of the task `dequeue` would return next, without
    /// removing it.
    async fn peek(&self) -> StorageResult<Option<WorkflowTask>> {
        self.increment_call_count("peek");
        let queue = self.queue.lock().unwrap();
        Ok(queue.first().map(|(task, _)| task.clone()))
    }

    /// Looks up a task by ID.
    async fn get_task(&self, id: &str) -> StorageResult<Option<WorkflowTask>> {
        self.increment_call_count("get_task");
        let tasks = self.tasks.lock().unwrap();
        Ok(tasks.get(id).cloned())
    }

    /// Inserts `task`, replacing any stored task with the same ID.
    async fn update_task(&self, task: WorkflowTask) -> StorageResult<()> {
        self.increment_call_count("update_task");
        let mut tasks = self.tasks.lock().unwrap();
        tasks.insert(task.id.clone(), task);
        Ok(())
    }

    /// Sets the status of the task with `id`.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::TaskNotFound` if no such task is stored.
    async fn update_task_status(&self, id: &str, status: TaskStatus) -> StorageResult<()> {
        self.increment_call_count("update_task_status");
        let mut tasks = self.tasks.lock().unwrap();
        match tasks.get_mut(id) {
            Some(task) => {
                task.status = status;
                Ok(())
            }
            None => Err(StorageError::TaskNotFound { id: id.to_string() }),
        }
    }

    /// Lists stored tasks, optionally restricted to one status.
    ///
    /// The order of the result follows `HashMap` iteration and is unspecified.
    async fn list_tasks(&self, status_filter: Option<TaskStatus>) -> StorageResult<Vec<WorkflowTask>> {
        self.increment_call_count("list_tasks");
        let tasks = self.tasks.lock().unwrap();
        let filtered_tasks: Vec<WorkflowTask> = tasks
            .values()
            .filter(|task| {
                status_filter
                    .as_ref()
                    .map_or(true, |status| &task.status == status)
            })
            .cloned()
            .collect();
        Ok(filtered_tasks)
    }

    /// Resets a `Failed` task to `Pending` and queues it with priority 1.
    ///
    /// Returns `true` when the task was requeued, `false` when the task does
    /// not exist or is not in the `Failed` state.
    async fn requeue_failed_task(&self, id: &str) -> StorageResult<bool> {
        self.increment_call_count("requeue_failed_task");
        // Lock order: tasks, then queue (same as `enqueue`).
        let mut tasks = self.tasks.lock().unwrap();
        let mut queue = self.queue.lock().unwrap();
        match tasks.get_mut(id) {
            Some(task) if task.status == TaskStatus::Failed => {
                task.status = TaskStatus::Pending;
                queue.push((task.clone(), 1));
                queue.sort_by(|a, b| b.1.cmp(&a.1));
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Number of entries currently in the queue.
    async fn queue_size(&self) -> StorageResult<usize> {
        self.increment_call_count("queue_size");
        let queue = self.queue.lock().unwrap();
        Ok(queue.len())
    }

    /// Number of tasks in the task map (queued or not).
    async fn total_tasks(&self) -> StorageResult<usize> {
        self.increment_call_count("total_tasks");
        let tasks = self.tasks.lock().unwrap();
        Ok(tasks.len())
    }

    /// Removes `Completed` tasks whose `completed_at` is earlier than
    /// `now - older_than` and returns how many were removed.
    ///
    /// Completed tasks without a `completed_at` timestamp are kept.
    async fn cleanup_completed_tasks(&self, older_than: Duration) -> StorageResult<usize> {
        self.increment_call_count("cleanup_completed_tasks");
        let mut tasks = self.tasks.lock().unwrap();
        let cutoff = Utc::now() - older_than;
        let initial_count = tasks.len();
        tasks.retain(|_, task| {
            let expired = task.status == TaskStatus::Completed
                && task.completed_at.map_or(false, |t| t < cutoff);
            !expired
        });
        Ok(initial_count - tasks.len())
    }

    /// Returns the entire audit log; the time range is ignored by the mock.
    async fn get_audit_log(&self, _time_range: TimeRange) -> StorageResult<Vec<AuditEntry>> {
        self.increment_call_count("get_audit_log");
        let audit_log = self.audit_log.lock().unwrap();
        Ok(audit_log.clone())
    }

    /// Appends `entry` to the audit log.
    async fn record_audit_entry(&self, entry: AuditEntry) -> StorageResult<()> {
        self.increment_call_count("record_audit_entry");
        let mut audit_log = self.audit_log.lock().unwrap();
        audit_log.push(entry);
        Ok(())
    }

    /// Returns every recorded metric; the time range is ignored by the mock.
    async fn get_metrics(&self, _time_range: TimeRange) -> StorageResult<Vec<Metric>> {
        self.increment_call_count("get_metrics");
        let metrics = self.metrics.lock().unwrap();
        Ok(metrics.clone())
    }

    /// Appends `metric` to the recorded metrics.
    async fn record_metric(&self, metric: Metric) -> StorageResult<()> {
        self.increment_call_count("record_metric");
        let mut metrics = self.metrics.lock().unwrap();
        metrics.push(metric);
        Ok(())
    }

    /// Issues a 24 hour `read`/`write` token for the fixed test credentials
    /// `test` / `test` and remembers it for `validate_token`.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::AuthenticationFailed` for any other credentials.
    async fn authenticate(&self, username: &str, password: &str) -> StorageResult<AuthToken> {
        self.increment_call_count("authenticate");
        if username != "test" || password != "test" {
            return Err(StorageError::AuthenticationFailed {
                reason: "Invalid credentials".to_string(),
            });
        }

        let token = AuthToken {
            token: Uuid::new_v4().to_string(),
            user_id: username.to_string(),
            expires_at: Utc::now() + Duration::hours(24),
            permissions: vec!["read".to_string(), "write".to_string()],
        };
        let mut tokens = self.auth_tokens.lock().unwrap();
        tokens.insert(token.token.clone(), token.clone());
        Ok(token)
    }

    /// Returns the stored token matching `token` unless it is unknown or
    /// `is_expired()` reports it as expired.
    async fn validate_token(&self, token: &str) -> StorageResult<Option<AuthToken>> {
        self.increment_call_count("validate_token");
        let tokens = self.auth_tokens.lock().unwrap();
        Ok(tokens
            .get(token)
            .filter(|auth_token| !auth_token.is_expired())
            .cloned())
    }

    /// Returns a finite stream replaying a snapshot of the events published
    /// so far; the filters are ignored by the mock.
    async fn subscribe_to_events(&self, _filters: Option<Vec<TaskEventType>>) -> StorageResult<BoxStream<'_, TaskEvent>> {
        self.increment_call_count("subscribe_to_events");
        let events = self.events.lock().unwrap().clone();
        // `Box::pin` yields the boxed stream directly, so no `StreamExt`
        // import is needed for `.boxed()`.
        Ok(Box::pin(stream::iter(events)))
    }

    /// Appends `event` to the published events.
    async fn publish_event(&self, event: TaskEvent) -> StorageResult<()> {
        self.increment_call_count("publish_event");
        let mut events = self.events.lock().unwrap();
        events.push(event);
        Ok(())
    }

    /// Returns stored tasks whose status is in `status_filter` (all tasks
    /// when the filter is `None`).
    ///
    /// The name pattern, creation-time bounds, limit and offset are ignored
    /// by the mock.
    async fn search_tasks(
        &self,
        _name_pattern: Option<String>,
        status_filter: Option<Vec<TaskStatus>>,
        _created_after: Option<chrono::DateTime<chrono::Utc>>,
        _created_before: Option<chrono::DateTime<chrono::Utc>>,
        _limit: Option<usize>,
        _offset: Option<usize>,
    ) -> StorageResult<Vec<WorkflowTask>> {
        self.increment_call_count("search_tasks");
        let tasks = self.tasks.lock().unwrap();
        let filtered_tasks: Vec<WorkflowTask> = tasks
            .values()
            .filter(|task| {
                status_filter
                    .as_ref()
                    .map_or(true, |statuses| statuses.contains(&task.status))
            })
            .cloned()
            .collect();
        Ok(filtered_tasks)
    }

    /// Returns the dependency IDs of `task_id`, or an empty list when the
    /// task is not stored.
    async fn get_task_dependencies(&self, task_id: &str) -> StorageResult<Vec<String>> {
        self.increment_call_count("get_task_dependencies");
        let tasks = self.tasks.lock().unwrap();
        Ok(tasks
            .get(task_id)
            .map(|task| task.dependencies.clone())
            .unwrap_or_default())
    }

    /// Returns the IDs of stored tasks that list `task_id` among their
    /// dependencies, sorted so the result is deterministic.
    async fn get_dependent_tasks(&self, task_id: &str) -> StorageResult<Vec<String>> {
        self.increment_call_count("get_dependent_tasks");
        let tasks = self.tasks.lock().unwrap();
        let mut dependents: Vec<String> = tasks
            .values()
            .filter(|task| task.dependencies.iter().any(|dep| dep == task_id))
            .map(|task| task.id.clone())
            .collect();
        // HashMap iteration order is unspecified; sort for stable output.
        dependents.sort();
        Ok(dependents)
    }

    /// No-op backup; always succeeds.
    async fn create_backup(&self, _backup_path: &str) -> StorageResult<()> {
        self.increment_call_count("create_backup");
        // Mock implementation
        Ok(())
    }

    /// No-op restore; always succeeds.
    async fn restore_from_backup(&self, _backup_path: &str) -> StorageResult<()> {
        self.increment_call_count("restore_from_backup");
        // Mock implementation
        Ok(())
    }

    /// Counts stored tasks in total and per `Pending`/`Running`/
    /// `Completed`/`Failed` status; other statuses only count towards the
    /// total.
    async fn get_statistics(&self) -> StorageResult<StorageStatistics> {
        self.increment_call_count("get_statistics");
        let tasks = self.tasks.lock().unwrap();
        let mut stats = StorageStatistics::new();
        stats.total_tasks = tasks.len();
        for task in tasks.values() {
            match &task.status {
                TaskStatus::Pending => stats.pending_tasks += 1,
                TaskStatus::Running => stats.running_tasks += 1,
                TaskStatus::Completed => stats.completed_tasks += 1,
                TaskStatus::Failed => stats.failed_tasks += 1,
                _ => {}
            }
        }
        Ok(stats)
    }
}
/// Common test assertions for storage implementations.
///
/// Unit struct used purely as a namespace: its associated functions take
/// the storage under test as an argument.
pub struct StorageAssertions;
impl StorageAssertions {
/// Assert task exists and has expected values
pub async fn assert_task_exists(
storage: &Box<dyn TaskStorage>,
task_id: &str,
expected_name: &str,
) -> StorageResult<()> {
let task = storage.get_task(task_id).await?
.ok_or_else(|| StorageError::TaskNotFound { id: task_id.to_string() })?;
assert_eq!(task.name, expected_name);
assert_eq!(task.id, task_id);
Ok(())
}
/// Assert task has expected status
pub async fn assert_task_status(
storage: &Box<dyn TaskStorage>,
task_id: &str,
expected_status: TaskStatus,
) -> StorageResult<()> {
let task = storage.get_task(task_id).await?
.ok_or_else(|| StorageError::TaskNotFound { id: task_id.to_string() })?;
assert_eq!(task.status, expected_status);
Ok(())
}
/// Assert storage has expected number of tasks
pub async fn assert_task_count(
storage: &Box<dyn TaskStorage>,
expected_count: usize,
) -> StorageResult<()> {
let total = storage.total_tasks().await?;
assert_eq!(total, expected_count);
Ok(())
}
/// Assert queue has expected size
pub async fn assert_queue_size(
storage: &Box<dyn TaskStorage>,
expected_size: usize,
) -> StorageResult<()> {
let size = storage.queue_size().await?;
assert_eq!(size, expected_size);
Ok(())
}
/// Assert storage health check passes
pub async fn assert_healthy(storage: &Box<dyn TaskStorage>) -> StorageResult<()> {
let healthy = storage.health_check().await?;
assert!(healthy, "Storage health check failed");
Ok(())
}
}
/// Declares a `#[tokio::test]` function named `$name` that builds a
/// `StorageTestRunner` and runs `$body` against every available storage
/// backend via `run_against_all_backends`.
#[macro_export]
macro_rules! test_all_backends {
    ($name:ident, $body:expr) => {
        #[tokio::test]
        async fn $name() {
            let mut storage_runner = $crate::helpers::StorageTestRunner::new();
            storage_runner.run_against_all_backends($body).await;
        }
    };
}
/// Declares a `#[tokio::test]` function named `$name` that builds a
/// `StorageTestRunner` and runs `$body` against the single backend named by
/// `$kind` via `run_against_backend`.
#[macro_export]
macro_rules! test_with_backend {
    ($name:ident, $kind:expr, $body:expr) => {
        #[tokio::test]
        async fn $name() {
            let mut storage_runner = $crate::helpers::StorageTestRunner::new();
            storage_runner.run_against_backend($kind, $body).await;
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generator hands out unique IDs and fills status-dependent fields.
    #[test]
    fn test_data_generator() {
        let generator = TestDataGenerator::new();

        // Two default tasks: distinct IDs, both start out pending.
        let first = generator.workflow_task();
        let second = generator.workflow_task();
        assert_ne!(first.id, second.id);
        assert_eq!(first.status, TaskStatus::Pending);

        // A running task carries its status and a start timestamp.
        let running = generator.workflow_task_with_status(TaskStatus::Running);
        assert_eq!(running.status, TaskStatus::Running);
        assert!(running.started_at.is_some());

        // A batch has the requested size...
        let batch = generator.workflow_tasks_batch(4);
        assert_eq!(batch.len(), 4);

        // ...and no duplicate IDs.
        let distinct_ids = batch
            .iter()
            .map(|task| &task.id)
            .collect::<std::collections::HashSet<_>>();
        assert_eq!(distinct_ids.len(), 4);
    }

    /// The mock tracks calls, health, and the enqueue/dequeue round trip.
    #[tokio::test]
    async fn test_mock_storage() {
        let storage = MockStorage::new();
        let generator = TestDataGenerator::new();

        // init is counted.
        storage.init().await.unwrap();
        assert_eq!(storage.get_call_count("init"), 1);

        // Health follows the flag set through `set_health`.
        assert!(storage.health_check().await.unwrap());
        storage.set_health(false);
        assert!(!storage.health_check().await.unwrap());

        // Enqueue stores the task and queues it.
        let queued = generator.workflow_task();
        let queued_id = queued.id.clone();
        storage.enqueue(queued.clone(), 5).await.unwrap();
        assert_eq!(storage.queue_size().await.unwrap(), 1);
        assert_eq!(storage.total_tasks().await.unwrap(), 1);

        // Dequeue hands the same task back and empties the queue.
        let popped = storage.dequeue().await.unwrap();
        assert!(popped.is_some());
        assert_eq!(popped.unwrap().id, queued_id);
        assert_eq!(storage.queue_size().await.unwrap(), 0);
        // Task still exists in storage
        assert_eq!(storage.total_tasks().await.unwrap(), 1);
    }
}