//! Workflow execution engine for batch operations.
//!
//! This module provides a configuration-driven batch workflow engine that integrates
//! with the existing storage abstraction. It supports dependency resolution,
//! parallel execution with a configurable concurrency limit, and mixed-provider operations.
|
||
|
|
|
||
|
|
use anyhow::{Context, Result};
|
||
|
|
use futures::{stream, StreamExt};
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::{
|
||
|
|
collections::{HashMap, HashSet, VecDeque},
|
||
|
|
sync::Arc,
|
||
|
|
time::{Duration, Instant},
|
||
|
|
};
|
||
|
|
use tokio::sync::{RwLock, Semaphore};
|
||
|
|
use tracing::{debug, error, info};
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
use crate::{
|
||
|
|
storage::TaskStorage,
|
||
|
|
TaskStatus, WorkflowTask,
|
||
|
|
};
|
||
|
|
|
||
|
|
/// Configuration for workflow execution
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowConfig {
|
||
|
|
/// Maximum number of parallel tasks
|
||
|
|
pub max_parallel_tasks: usize,
|
||
|
|
/// Timeout for individual tasks in seconds
|
||
|
|
pub task_timeout_seconds: u64,
|
||
|
|
/// Maximum retry attempts for failed tasks
|
||
|
|
pub max_retries: u8,
|
||
|
|
/// Delay between retries in seconds
|
||
|
|
pub retry_delay_seconds: u64,
|
||
|
|
/// Whether to fail fast on first error
|
||
|
|
pub fail_fast: bool,
|
||
|
|
/// Checkpoint interval in seconds
|
||
|
|
pub checkpoint_interval_seconds: u64,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for WorkflowConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
max_parallel_tasks: 4,
|
||
|
|
task_timeout_seconds: 3600, // 1 hour
|
||
|
|
max_retries: 3,
|
||
|
|
retry_delay_seconds: 30,
|
||
|
|
fail_fast: false,
|
||
|
|
checkpoint_interval_seconds: 300, // 5 minutes
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// KCL workflow definition parsed from configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowDefinition {
|
||
|
|
pub name: String,
|
||
|
|
pub description: Option<String>,
|
||
|
|
pub tasks: Vec<WorkflowTaskDefinition>,
|
||
|
|
pub config: Option<WorkflowConfig>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Individual task definition in workflow
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowTaskDefinition {
|
||
|
|
pub name: String,
|
||
|
|
pub command: String,
|
||
|
|
pub args: Vec<String>,
|
||
|
|
pub dependencies: Vec<String>,
|
||
|
|
pub provider: Option<String>,
|
||
|
|
pub timeout_seconds: Option<u64>,
|
||
|
|
pub max_retries: Option<u8>,
|
||
|
|
pub environment: Option<HashMap<String, String>>,
|
||
|
|
pub metadata: Option<HashMap<String, String>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execution state for a workflow
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowExecutionState {
|
||
|
|
pub workflow_id: String,
|
||
|
|
pub status: WorkflowStatus,
|
||
|
|
pub started_at: chrono::DateTime<chrono::Utc>,
|
||
|
|
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
||
|
|
pub task_states: HashMap<String, TaskExecutionState>,
|
||
|
|
pub execution_graph: DependencyGraph,
|
||
|
|
pub checkpoints: Vec<WorkflowCheckpoint>,
|
||
|
|
pub statistics: WorkflowStatistics,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Overall workflow status
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||
|
|
pub enum WorkflowStatus {
|
||
|
|
Pending,
|
||
|
|
Running,
|
||
|
|
Completed,
|
||
|
|
Failed,
|
||
|
|
Cancelled,
|
||
|
|
Paused,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl std::fmt::Display for WorkflowStatus {
|
||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
|
match self {
|
||
|
|
WorkflowStatus::Pending => write!(f, "pending"),
|
||
|
|
WorkflowStatus::Running => write!(f, "running"),
|
||
|
|
WorkflowStatus::Completed => write!(f, "completed"),
|
||
|
|
WorkflowStatus::Failed => write!(f, "failed"),
|
||
|
|
WorkflowStatus::Cancelled => write!(f, "cancelled"),
|
||
|
|
WorkflowStatus::Paused => write!(f, "paused"),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execution state for individual task
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct TaskExecutionState {
|
||
|
|
pub task_id: String,
|
||
|
|
pub status: TaskStatus,
|
||
|
|
pub retry_count: u8,
|
||
|
|
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
|
||
|
|
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
|
||
|
|
pub duration_ms: Option<u64>,
|
||
|
|
pub error: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Dependency graph for task ordering
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct DependencyGraph {
|
||
|
|
pub nodes: HashSet<String>,
|
||
|
|
pub edges: HashMap<String, Vec<String>>,
|
||
|
|
pub reverse_edges: HashMap<String, Vec<String>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl DependencyGraph {
|
||
|
|
/// Create new empty dependency graph
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
nodes: HashSet::new(),
|
||
|
|
edges: HashMap::new(),
|
||
|
|
reverse_edges: HashMap::new(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Add a task node to the graph
|
||
|
|
pub fn add_node(&mut self, task_name: String) {
|
||
|
|
self.nodes.insert(task_name.clone());
|
||
|
|
self.edges.entry(task_name.clone()).or_default();
|
||
|
|
self.reverse_edges.entry(task_name).or_default();
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Add a dependency edge (from depends on to)
|
||
|
|
pub fn add_dependency(&mut self, from: String, depends_on: String) {
|
||
|
|
self.edges.entry(from.clone()).or_default().push(depends_on.clone());
|
||
|
|
self.reverse_edges.entry(depends_on).or_default().push(from);
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get tasks that have no pending dependencies
|
||
|
|
pub fn get_ready_tasks(&self, completed_tasks: &HashSet<String>) -> Vec<String> {
|
||
|
|
self.nodes
|
||
|
|
.iter()
|
||
|
|
.filter(|task| {
|
||
|
|
!completed_tasks.contains(*task) &&
|
||
|
|
self.edges
|
||
|
|
.get(*task)
|
||
|
|
.map(|deps| deps.iter().all(|dep| completed_tasks.contains(dep)))
|
||
|
|
.unwrap_or(true)
|
||
|
|
})
|
||
|
|
.cloned()
|
||
|
|
.collect()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Topologically sort tasks
|
||
|
|
pub fn topological_sort(&self) -> Result<Vec<String>> {
|
||
|
|
let mut in_degree: HashMap<String, usize> = HashMap::new();
|
||
|
|
let mut result = Vec::new();
|
||
|
|
let mut queue = VecDeque::new();
|
||
|
|
|
||
|
|
// Calculate in-degrees
|
||
|
|
for node in &self.nodes {
|
||
|
|
in_degree.insert(node.clone(), 0);
|
||
|
|
}
|
||
|
|
|
||
|
|
for (_, deps) in &self.edges {
|
||
|
|
for dep in deps {
|
||
|
|
*in_degree.entry(dep.clone()).or_default() += 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Find nodes with no incoming edges
|
||
|
|
for (node, degree) in &in_degree {
|
||
|
|
if *degree == 0 {
|
||
|
|
queue.push_back(node.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Process queue
|
||
|
|
while let Some(node) = queue.pop_front() {
|
||
|
|
result.push(node.clone());
|
||
|
|
|
||
|
|
if let Some(dependents) = self.reverse_edges.get(&node) {
|
||
|
|
for dependent in dependents {
|
||
|
|
if let Some(degree) = in_degree.get_mut(dependent) {
|
||
|
|
*degree -= 1;
|
||
|
|
if *degree == 0 {
|
||
|
|
queue.push_back(dependent.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if result.len() != self.nodes.len() {
|
||
|
|
return Err(anyhow::anyhow!("Circular dependency detected in workflow"));
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(result)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Workflow execution checkpoint for recovery
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowCheckpoint {
|
||
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||
|
|
pub completed_tasks: HashSet<String>,
|
||
|
|
pub failed_tasks: HashSet<String>,
|
||
|
|
pub state_snapshot: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Workflow execution statistics
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct WorkflowStatistics {
|
||
|
|
pub total_tasks: usize,
|
||
|
|
pub completed_tasks: usize,
|
||
|
|
pub failed_tasks: usize,
|
||
|
|
pub running_tasks: usize,
|
||
|
|
pub pending_tasks: usize,
|
||
|
|
pub total_duration_ms: u64,
|
||
|
|
pub average_task_duration_ms: Option<u64>,
|
||
|
|
pub throughput_tasks_per_minute: f64,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl WorkflowStatistics {
|
||
|
|
pub fn new(total_tasks: usize) -> Self {
|
||
|
|
Self {
|
||
|
|
total_tasks,
|
||
|
|
completed_tasks: 0,
|
||
|
|
failed_tasks: 0,
|
||
|
|
running_tasks: 0,
|
||
|
|
pending_tasks: total_tasks,
|
||
|
|
total_duration_ms: 0,
|
||
|
|
average_task_duration_ms: None,
|
||
|
|
throughput_tasks_per_minute: 0.0,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Main workflow execution engine
|
||
|
|
pub struct BatchWorkflowEngine {
|
||
|
|
storage: Arc<dyn TaskStorage>,
|
||
|
|
config: WorkflowConfig,
|
||
|
|
execution_states: Arc<RwLock<HashMap<String, WorkflowExecutionState>>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl BatchWorkflowEngine {
|
||
|
|
/// Create new workflow engine
|
||
|
|
pub fn new(storage: Arc<dyn TaskStorage>, config: WorkflowConfig) -> Self {
|
||
|
|
Self {
|
||
|
|
storage,
|
||
|
|
config,
|
||
|
|
execution_states: Arc::new(RwLock::new(HashMap::new())),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execute a workflow definition
|
||
|
|
pub async fn execute_workflow(
|
||
|
|
&self,
|
||
|
|
definition: WorkflowDefinition,
|
||
|
|
) -> Result<WorkflowExecutionState> {
|
||
|
|
let workflow_id = Uuid::new_v4().to_string();
|
||
|
|
info!("Starting workflow execution: {} ({})", definition.name, workflow_id);
|
||
|
|
|
||
|
|
// Build dependency graph
|
||
|
|
let mut graph = DependencyGraph::new();
|
||
|
|
for task_def in &definition.tasks {
|
||
|
|
graph.add_node(task_def.name.clone());
|
||
|
|
}
|
||
|
|
|
||
|
|
for task_def in &definition.tasks {
|
||
|
|
for dep in &task_def.dependencies {
|
||
|
|
graph.add_dependency(task_def.name.clone(), dep.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate dependency graph
|
||
|
|
let _sorted_tasks = graph.topological_sort()
|
||
|
|
.context("Failed to resolve task dependencies")?;
|
||
|
|
|
||
|
|
// Create initial execution state
|
||
|
|
let mut execution_state = WorkflowExecutionState {
|
||
|
|
workflow_id: workflow_id.clone(),
|
||
|
|
status: WorkflowStatus::Running,
|
||
|
|
started_at: chrono::Utc::now(),
|
||
|
|
completed_at: None,
|
||
|
|
task_states: HashMap::new(),
|
||
|
|
execution_graph: graph,
|
||
|
|
checkpoints: Vec::new(),
|
||
|
|
statistics: WorkflowStatistics::new(definition.tasks.len()),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Initialize task states
|
||
|
|
for task_def in &definition.tasks {
|
||
|
|
execution_state.task_states.insert(
|
||
|
|
task_def.name.clone(),
|
||
|
|
TaskExecutionState {
|
||
|
|
task_id: Uuid::new_v4().to_string(),
|
||
|
|
status: TaskStatus::Pending,
|
||
|
|
retry_count: 0,
|
||
|
|
started_at: None,
|
||
|
|
completed_at: None,
|
||
|
|
duration_ms: None,
|
||
|
|
error: None,
|
||
|
|
}
|
||
|
|
);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Store initial state
|
||
|
|
{
|
||
|
|
let mut states = self.execution_states.write().await;
|
||
|
|
states.insert(workflow_id.clone(), execution_state.clone());
|
||
|
|
}
|
||
|
|
|
||
|
|
// Execute workflow
|
||
|
|
let final_state = self.execute_workflow_internal(
|
||
|
|
workflow_id.clone(),
|
||
|
|
definition,
|
||
|
|
execution_state,
|
||
|
|
).await?;
|
||
|
|
|
||
|
|
info!("Workflow execution completed: {} ({})",
|
||
|
|
final_state.workflow_id, final_state.status);
|
||
|
|
|
||
|
|
Ok(final_state)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Internal workflow execution logic
|
||
|
|
async fn execute_workflow_internal(
|
||
|
|
&self,
|
||
|
|
workflow_id: String,
|
||
|
|
definition: WorkflowDefinition,
|
||
|
|
mut state: WorkflowExecutionState,
|
||
|
|
) -> Result<WorkflowExecutionState> {
|
||
|
|
let semaphore = Arc::new(Semaphore::new(self.config.max_parallel_tasks));
|
||
|
|
let mut completed_tasks = HashSet::new();
|
||
|
|
let mut failed_tasks = HashSet::new();
|
||
|
|
let start_time = Instant::now();
|
||
|
|
let mut last_checkpoint = Instant::now();
|
||
|
|
|
||
|
|
loop {
|
||
|
|
// Check for checkpoint
|
||
|
|
if last_checkpoint.elapsed().as_secs() >= self.config.checkpoint_interval_seconds {
|
||
|
|
self.create_checkpoint(&mut state, &completed_tasks, &failed_tasks).await?;
|
||
|
|
last_checkpoint = Instant::now();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get ready tasks
|
||
|
|
let ready_tasks = state.execution_graph.get_ready_tasks(&completed_tasks);
|
||
|
|
let pending_ready_tasks: Vec<_> = ready_tasks
|
||
|
|
.into_iter()
|
||
|
|
.filter(|task| {
|
||
|
|
!completed_tasks.contains(task) &&
|
||
|
|
!failed_tasks.contains(task) &&
|
||
|
|
state.task_states.get(task)
|
||
|
|
.map(|ts| ts.status == TaskStatus::Pending)
|
||
|
|
.unwrap_or(false)
|
||
|
|
})
|
||
|
|
.collect();
|
||
|
|
|
||
|
|
if pending_ready_tasks.is_empty() {
|
||
|
|
// Check if all tasks are done
|
||
|
|
let total_done = completed_tasks.len() + failed_tasks.len();
|
||
|
|
if total_done >= definition.tasks.len() {
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Wait for running tasks
|
||
|
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Execute ready tasks in parallel
|
||
|
|
let tasks_futures = pending_ready_tasks.into_iter().map(|task_name| {
|
||
|
|
let task_def = definition.tasks.iter()
|
||
|
|
.find(|t| t.name == task_name)
|
||
|
|
.unwrap()
|
||
|
|
.clone();
|
||
|
|
|
||
|
|
let semaphore = semaphore.clone();
|
||
|
|
let storage = self.storage.clone();
|
||
|
|
let workflow_id = workflow_id.clone();
|
||
|
|
let config = self.config.clone();
|
||
|
|
|
||
|
|
async move {
|
||
|
|
let _permit = semaphore.acquire().await.unwrap();
|
||
|
|
self.execute_single_task(
|
||
|
|
workflow_id,
|
||
|
|
task_name.clone(),
|
||
|
|
task_def,
|
||
|
|
storage,
|
||
|
|
config,
|
||
|
|
).await
|
||
|
|
}
|
||
|
|
});
|
||
|
|
|
||
|
|
// Wait for batch completion
|
||
|
|
let results: Vec<_> = stream::iter(tasks_futures)
|
||
|
|
.buffer_unordered(self.config.max_parallel_tasks)
|
||
|
|
.collect()
|
||
|
|
.await;
|
||
|
|
|
||
|
|
// Process results
|
||
|
|
for (task_name, result) in results {
|
||
|
|
match result {
|
||
|
|
Ok(_) => {
|
||
|
|
completed_tasks.insert(task_name.clone());
|
||
|
|
if let Some(task_state) = state.task_states.get_mut(&task_name) {
|
||
|
|
task_state.status = TaskStatus::Completed;
|
||
|
|
task_state.completed_at = Some(chrono::Utc::now());
|
||
|
|
}
|
||
|
|
state.statistics.completed_tasks += 1;
|
||
|
|
state.statistics.pending_tasks -= 1;
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
error!("Task {} failed: {}", task_name, e);
|
||
|
|
failed_tasks.insert(task_name.clone());
|
||
|
|
if let Some(task_state) = state.task_states.get_mut(&task_name) {
|
||
|
|
task_state.status = TaskStatus::Failed;
|
||
|
|
task_state.completed_at = Some(chrono::Utc::now());
|
||
|
|
task_state.error = Some(e.to_string());
|
||
|
|
}
|
||
|
|
state.statistics.failed_tasks += 1;
|
||
|
|
state.statistics.pending_tasks -= 1;
|
||
|
|
|
||
|
|
if self.config.fail_fast {
|
||
|
|
state.status = WorkflowStatus::Failed;
|
||
|
|
return Ok(state);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Update execution state
|
||
|
|
{
|
||
|
|
let mut states = self.execution_states.write().await;
|
||
|
|
states.insert(workflow_id.clone(), state.clone());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Finalize workflow
|
||
|
|
state.completed_at = Some(chrono::Utc::now());
|
||
|
|
state.statistics.total_duration_ms = start_time.elapsed().as_millis() as u64;
|
||
|
|
|
||
|
|
if state.statistics.failed_tasks > 0 {
|
||
|
|
state.status = WorkflowStatus::Failed;
|
||
|
|
} else {
|
||
|
|
state.status = WorkflowStatus::Completed;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Calculate final statistics
|
||
|
|
if state.statistics.completed_tasks > 0 {
|
||
|
|
let total_task_duration: u64 = state.task_states.values()
|
||
|
|
.filter_map(|ts| ts.duration_ms)
|
||
|
|
.sum();
|
||
|
|
|
||
|
|
state.statistics.average_task_duration_ms =
|
||
|
|
Some(total_task_duration / state.statistics.completed_tasks as u64);
|
||
|
|
}
|
||
|
|
|
||
|
|
let duration_minutes = state.statistics.total_duration_ms as f64 / 60000.0;
|
||
|
|
if duration_minutes > 0.0 {
|
||
|
|
state.statistics.throughput_tasks_per_minute =
|
||
|
|
state.statistics.completed_tasks as f64 / duration_minutes;
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(state)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execute a single task
|
||
|
|
async fn execute_single_task(
|
||
|
|
&self,
|
||
|
|
workflow_id: String,
|
||
|
|
task_name: String,
|
||
|
|
task_def: WorkflowTaskDefinition,
|
||
|
|
storage: Arc<dyn TaskStorage>,
|
||
|
|
config: WorkflowConfig,
|
||
|
|
) -> (String, Result<()>) {
|
||
|
|
let task_start = Instant::now();
|
||
|
|
|
||
|
|
// Update task state to running
|
||
|
|
{
|
||
|
|
let mut states = self.execution_states.write().await;
|
||
|
|
if let Some(state) = states.get_mut(&workflow_id) {
|
||
|
|
if let Some(task_state) = state.task_states.get_mut(&task_name) {
|
||
|
|
task_state.status = TaskStatus::Running;
|
||
|
|
task_state.started_at = Some(chrono::Utc::now());
|
||
|
|
}
|
||
|
|
state.statistics.running_tasks += 1;
|
||
|
|
state.statistics.pending_tasks -= 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
let result = self.execute_task_with_retry(
|
||
|
|
task_name.clone(),
|
||
|
|
task_def,
|
||
|
|
storage,
|
||
|
|
config,
|
||
|
|
).await;
|
||
|
|
|
||
|
|
// Update task duration
|
||
|
|
let duration_ms = task_start.elapsed().as_millis() as u64;
|
||
|
|
{
|
||
|
|
let mut states = self.execution_states.write().await;
|
||
|
|
if let Some(state) = states.get_mut(&workflow_id) {
|
||
|
|
if let Some(task_state) = state.task_states.get_mut(&task_name) {
|
||
|
|
task_state.duration_ms = Some(duration_ms);
|
||
|
|
}
|
||
|
|
state.statistics.running_tasks -= 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
(task_name, result)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execute task with retry logic
|
||
|
|
async fn execute_task_with_retry(
|
||
|
|
&self,
|
||
|
|
task_name: String,
|
||
|
|
task_def: WorkflowTaskDefinition,
|
||
|
|
storage: Arc<dyn TaskStorage>,
|
||
|
|
config: WorkflowConfig,
|
||
|
|
) -> Result<()> {
|
||
|
|
let max_retries = task_def.max_retries.unwrap_or(config.max_retries);
|
||
|
|
let mut last_error = None;
|
||
|
|
|
||
|
|
for retry in 0..=max_retries {
|
||
|
|
if retry > 0 {
|
||
|
|
info!("Retrying task {} (attempt {}/{})", task_name, retry + 1, max_retries + 1);
|
||
|
|
tokio::time::sleep(Duration::from_secs(config.retry_delay_seconds)).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create workflow task
|
||
|
|
let workflow_task = WorkflowTask {
|
||
|
|
id: Uuid::new_v4().to_string(),
|
||
|
|
name: task_name.clone(),
|
||
|
|
command: task_def.command.clone(),
|
||
|
|
args: task_def.args.clone(),
|
||
|
|
dependencies: task_def.dependencies.clone(),
|
||
|
|
status: TaskStatus::Pending,
|
||
|
|
created_at: chrono::Utc::now(),
|
||
|
|
started_at: None,
|
||
|
|
completed_at: None,
|
||
|
|
output: None,
|
||
|
|
error: None,
|
||
|
|
};
|
||
|
|
|
||
|
|
// Execute via storage backend
|
||
|
|
match storage.enqueue(workflow_task.clone(), 1).await {
|
||
|
|
Ok(_) => {
|
||
|
|
// Wait for completion
|
||
|
|
let timeout = Duration::from_secs(
|
||
|
|
task_def.timeout_seconds.unwrap_or(config.task_timeout_seconds)
|
||
|
|
);
|
||
|
|
|
||
|
|
match self.wait_for_task_completion(&workflow_task.id, timeout, storage.clone()).await {
|
||
|
|
Ok(_) => return Ok(()),
|
||
|
|
Err(e) => {
|
||
|
|
last_error = Some(e);
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
last_error = Some(e.into());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Task execution failed after retries")))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Wait for task completion with timeout
|
||
|
|
async fn wait_for_task_completion(
|
||
|
|
&self,
|
||
|
|
task_id: &str,
|
||
|
|
timeout: Duration,
|
||
|
|
storage: Arc<dyn TaskStorage>,
|
||
|
|
) -> Result<WorkflowTask> {
|
||
|
|
let start = Instant::now();
|
||
|
|
let poll_interval = Duration::from_secs(1);
|
||
|
|
|
||
|
|
loop {
|
||
|
|
if start.elapsed() > timeout {
|
||
|
|
return Err(anyhow::anyhow!("Task execution timeout"));
|
||
|
|
}
|
||
|
|
|
||
|
|
match storage.get_task(task_id).await {
|
||
|
|
Ok(Some(task)) => {
|
||
|
|
match task.status {
|
||
|
|
TaskStatus::Completed => return Ok(task),
|
||
|
|
TaskStatus::Failed | TaskStatus::Cancelled => {
|
||
|
|
return Err(anyhow::anyhow!(
|
||
|
|
"Task failed: {}",
|
||
|
|
task.error.unwrap_or_else(|| "Unknown error".to_string())
|
||
|
|
));
|
||
|
|
}
|
||
|
|
_ => {
|
||
|
|
tokio::time::sleep(poll_interval).await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok(None) => {
|
||
|
|
return Err(anyhow::anyhow!("Task not found"));
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
return Err(anyhow::anyhow!("Storage error: {}", e));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Create workflow checkpoint
|
||
|
|
async fn create_checkpoint(
|
||
|
|
&self,
|
||
|
|
state: &mut WorkflowExecutionState,
|
||
|
|
completed_tasks: &HashSet<String>,
|
||
|
|
failed_tasks: &HashSet<String>,
|
||
|
|
) -> Result<()> {
|
||
|
|
let checkpoint = WorkflowCheckpoint {
|
||
|
|
timestamp: chrono::Utc::now(),
|
||
|
|
completed_tasks: completed_tasks.clone(),
|
||
|
|
failed_tasks: failed_tasks.clone(),
|
||
|
|
state_snapshot: serde_json::to_string(state)?,
|
||
|
|
};
|
||
|
|
|
||
|
|
state.checkpoints.push(checkpoint);
|
||
|
|
|
||
|
|
// Keep only last 10 checkpoints
|
||
|
|
if state.checkpoints.len() > 10 {
|
||
|
|
state.checkpoints.drain(0..state.checkpoints.len() - 10);
|
||
|
|
}
|
||
|
|
|
||
|
|
debug!("Created workflow checkpoint for {}", state.workflow_id);
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get workflow execution state
|
||
|
|
pub async fn get_workflow_state(&self, workflow_id: &str) -> Option<WorkflowExecutionState> {
|
||
|
|
let states = self.execution_states.read().await;
|
||
|
|
states.get(workflow_id).cloned()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// List all workflow states
|
||
|
|
pub async fn list_workflows(&self) -> Vec<WorkflowExecutionState> {
|
||
|
|
let states = self.execution_states.read().await;
|
||
|
|
states.values().cloned().collect()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Cancel a running workflow
|
||
|
|
pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<()> {
|
||
|
|
let mut states = self.execution_states.write().await;
|
||
|
|
if let Some(state) = states.get_mut(workflow_id) {
|
||
|
|
if state.status == WorkflowStatus::Running {
|
||
|
|
state.status = WorkflowStatus::Cancelled;
|
||
|
|
state.completed_at = Some(chrono::Utc::now());
|
||
|
|
info!("Cancelled workflow: {}", workflow_id);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Parse KCL workflow definition from JSON
|
||
|
|
pub fn parse_kcl_workflow(json_content: &str) -> Result<WorkflowDefinition> {
|
||
|
|
serde_json::from_str(json_content)
|
||
|
|
.context("Failed to parse KCL workflow definition")
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Validate workflow definition
|
||
|
|
pub fn validate_workflow_definition(definition: &WorkflowDefinition) -> Result<()> {
|
||
|
|
// Check for duplicate task names
|
||
|
|
let mut task_names = HashSet::new();
|
||
|
|
for task in &definition.tasks {
|
||
|
|
if !task_names.insert(&task.name) {
|
||
|
|
return Err(anyhow::anyhow!("Duplicate task name: {}", task.name));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Check dependency references
|
||
|
|
for task in &definition.tasks {
|
||
|
|
for dep in &task.dependencies {
|
||
|
|
if !task_names.contains(dep) {
|
||
|
|
return Err(anyhow::anyhow!(
|
||
|
|
"Task '{}' has invalid dependency: '{}'",
|
||
|
|
task.name, dep
|
||
|
|
));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|