prvng_platform/orchestrator/src/workflow.rs

710 lines
24 KiB
Rust
Raw Normal View History

2025-10-07 10:59:52 +01:00
//! Workflow execution engine for batch operations
//!
//! This module provides a configuration-driven batch workflow engine that integrates
//! with the existing storage abstraction. It supports dependency resolution,
//! parallel execution with limits, and mixed provider operations.
use anyhow::{Context, Result};
use futures::{stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, error, info};
use uuid::Uuid;
use crate::{
storage::TaskStorage,
TaskStatus, WorkflowTask,
};
/// Configuration for workflow execution.
///
/// Execution settings used by `BatchWorkflowEngine`. A task definition can
/// override `task_timeout_seconds` and `max_retries` for itself through its own
/// `timeout_seconds` / `max_retries` fields. Default values come from the
/// `Default` impl for this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
/// Maximum number of tasks executed concurrently
pub max_parallel_tasks: usize,
/// Timeout for a single attempt of a task, in seconds
/// (a task may override it with its own `timeout_seconds`)
pub task_timeout_seconds: u64,
/// Maximum retry attempts for failed tasks; a task is attempted up to
/// `max_retries + 1` times (a task may override it with its own `max_retries`)
pub max_retries: u8,
/// Delay before each retry attempt, in seconds
pub retry_delay_seconds: u64,
/// Whether to stop the workflow after a task failure instead of continuing
/// with the tasks that do not depend on it
pub fail_fast: bool,
/// Minimum time between workflow checkpoints, in seconds; the engine checks
/// it between task batches, not on a timer
pub checkpoint_interval_seconds: u64,
}
impl Default for WorkflowConfig {
fn default() -> Self {
Self {
max_parallel_tasks: 4,
task_timeout_seconds: 3600, // 1 hour
max_retries: 3,
retry_delay_seconds: 30,
fail_fast: false,
checkpoint_interval_seconds: 300, // 5 minutes
}
}
}
/// KCL workflow definition parsed from configuration.
///
/// Produced from the JSON form of the definition by `parse_kcl_workflow`;
/// `validate_workflow_definition` checks, among other things, that task names
/// are unique and that dependencies refer to existing tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDefinition {
/// Human-readable workflow name (appears in log messages)
pub name: String,
/// Optional free-form description
pub description: Option<String>,
/// Tasks to run; each one is identified by its `name`
pub tasks: Vec<WorkflowTaskDefinition>,
/// Optional workflow-specific execution settings
pub config: Option<WorkflowConfig>,
}
/// Individual task definition in a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowTaskDefinition {
/// Task name; other tasks refer to it in their `dependencies`
pub name: String,
/// Command passed to the storage backend as the enqueued task's `command`
pub command: String,
/// Arguments for `command`
pub args: Vec<String>,
/// Names of the tasks that must complete before this one can start
pub dependencies: Vec<String>,
/// NOTE(review): not forwarded to the enqueued `WorkflowTask` by the engine in this file
pub provider: Option<String>,
/// Per-attempt timeout in seconds; falls back to `WorkflowConfig::task_timeout_seconds`
pub timeout_seconds: Option<u64>,
/// Per-task retry limit; falls back to `WorkflowConfig::max_retries`
pub max_retries: Option<u8>,
/// NOTE(review): not forwarded to the enqueued `WorkflowTask` by the engine in this file
pub environment: Option<HashMap<String, String>>,
/// NOTE(review): not forwarded to the enqueued `WorkflowTask` by the engine in this file
pub metadata: Option<HashMap<String, String>>,
}
/// Execution state for a workflow.
///
/// `BatchWorkflowEngine` keeps one of these per workflow id and hands out
/// clones through its query methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowExecutionState {
/// Identifier generated when execution starts (a UUID string)
pub workflow_id: String,
/// Overall workflow status
pub status: WorkflowStatus,
/// When execution started (UTC)
pub started_at: chrono::DateTime<chrono::Utc>,
/// Completion or cancellation time, once recorded
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
/// Per-task execution state, keyed by task name
pub task_states: HashMap<String, TaskExecutionState>,
/// Dependency graph built from the task definitions
pub execution_graph: DependencyGraph,
/// Periodic checkpoints; only the most recent ones are retained
pub checkpoints: Vec<WorkflowCheckpoint>,
/// Aggregate counters and timings
pub statistics: WorkflowStatistics,
}
/// Overall workflow status.
///
/// Rendered as its lowercase name by the `Display` impl (e.g. `running`).
/// NOTE(review): `Pending` and `Paused` are never assigned by
/// `BatchWorkflowEngine` in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WorkflowStatus {
Pending,
Running,
Completed,
Failed,
Cancelled,
Paused,
}
impl std::fmt::Display for WorkflowStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WorkflowStatus::Pending => write!(f, "pending"),
WorkflowStatus::Running => write!(f, "running"),
WorkflowStatus::Completed => write!(f, "completed"),
WorkflowStatus::Failed => write!(f, "failed"),
WorkflowStatus::Cancelled => write!(f, "cancelled"),
WorkflowStatus::Paused => write!(f, "paused"),
}
}
}
/// Execution state for an individual task within a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskExecutionState {
/// UUID assigned when the workflow starts.
/// NOTE(review): distinct from the ids of the `WorkflowTask`s enqueued in
/// storage, which receive a fresh UUID on every attempt.
pub task_id: String,
/// Current task status
pub status: TaskStatus,
/// Retry counter for the task; starts at 0
pub retry_count: u8,
/// When the task started running, if it has
pub started_at: Option<chrono::DateTime<chrono::Utc>>,
/// When the task finished, if it has
pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
/// Wall-clock run time in milliseconds (all attempts and retry delays), once known
pub duration_ms: Option<u64>,
/// Error message recorded for a task that did not succeed
pub error: Option<String>,
}
/// Dependency graph for task ordering.
///
/// * `nodes` — every task name in the graph.
/// * `edges` — task name → names of the tasks it depends on.
/// * `reverse_edges` — task name → names of the tasks that depend on it.
///
/// `Default` yields the same empty graph as `DependencyGraph::new()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
    pub nodes: HashSet<String>,
    pub edges: HashMap<String, Vec<String>>,
    pub reverse_edges: HashMap<String, Vec<String>>,
}
impl DependencyGraph {
    /// Create a new, empty dependency graph.
    pub fn new() -> Self {
        Self {
            nodes: HashSet::new(),
            edges: HashMap::new(),
            reverse_edges: HashMap::new(),
        }
    }

    /// Add a task node to the graph.
    ///
    /// Also creates empty adjacency lists for the task, so a task without
    /// dependencies (or dependents) still has entries in `edges` and
    /// `reverse_edges`.
    pub fn add_node(&mut self, task_name: String) {
        self.nodes.insert(task_name.clone());
        self.edges.entry(task_name.clone()).or_default();
        self.reverse_edges.entry(task_name).or_default();
    }

    /// Record that `from` depends on `depends_on`, i.e. `depends_on` must
    /// complete before `from` can start.
    ///
    /// Neither endpoint is added to `nodes`; call [`Self::add_node`] for both.
    pub fn add_dependency(&mut self, from: String, depends_on: String) {
        self.edges
            .entry(from.clone())
            .or_default()
            .push(depends_on.clone());
        self.reverse_edges.entry(depends_on).or_default().push(from);
    }

    /// Tasks that are not in `completed_tasks` but whose dependencies all are.
    ///
    /// The result is in arbitrary (hash-set) order.
    pub fn get_ready_tasks(&self, completed_tasks: &HashSet<String>) -> Vec<String> {
        self.nodes
            .iter()
            .filter(|task| {
                !completed_tasks.contains(*task)
                    && self.edges.get(*task).map_or(true, |deps| {
                        deps.iter().all(|dep| completed_tasks.contains(dep))
                    })
            })
            .cloned()
            .collect()
    }

    /// Topologically sort the tasks (Kahn's algorithm).
    ///
    /// Every task appears after all of its dependencies. The initial
    /// dependency-free tasks are taken in lexicographic order.
    ///
    /// # Errors
    ///
    /// Returns an error if a task depends on a name that is not a node, or if
    /// the dependencies contain a cycle.
    pub fn topological_sort(&self) -> Result<Vec<String>> {
        // in_degree[task] = number of dependencies `task` is still waiting on.
        // (The count must be over a task's *dependencies*, which are released
        // as those dependencies are emitted via `reverse_edges`.)
        let mut in_degree: HashMap<&str, usize> = HashMap::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let deps = self.edges.get(node).map(Vec::as_slice).unwrap_or(&[]);
            if let Some(missing) = deps.iter().find(|dep| !self.nodes.contains(*dep)) {
                return Err(anyhow::anyhow!(
                    "Task '{}' depends on unknown task '{}'",
                    node,
                    missing
                ));
            }
            in_degree.insert(node.as_str(), deps.len());
        }

        // Seed with dependency-free tasks, sorted so the output is deterministic.
        let mut roots: Vec<&str> = in_degree
            .iter()
            .filter(|&(_, &degree)| degree == 0)
            .map(|(&name, _)| name)
            .collect();
        roots.sort_unstable();
        let mut queue: VecDeque<&str> = roots.into();

        let mut result = Vec::with_capacity(self.nodes.len());
        while let Some(node) = queue.pop_front() {
            result.push(node.to_string());
            for dependent in self.reverse_edges.get(node).into_iter().flatten() {
                if let Some(degree) = in_degree.get_mut(dependent.as_str()) {
                    // Guard against underflow if the public edge maps were edited
                    // out of sync with each other.
                    if *degree > 0 {
                        *degree -= 1;
                        if *degree == 0 {
                            queue.push_back(dependent.as_str());
                        }
                    }
                }
            }
        }

        if result.len() != self.nodes.len() {
            return Err(anyhow::anyhow!("Circular dependency detected in workflow"));
        }
        Ok(result)
    }
}
/// Workflow execution checkpoint for recovery.
///
/// NOTE(review): nothing in this file restores a workflow from a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
/// When the checkpoint was taken (UTC)
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Names of the tasks completed at checkpoint time
pub completed_tasks: HashSet<String>,
/// Names of the tasks that had failed at checkpoint time
pub failed_tasks: HashSet<String>,
/// JSON serialization of the `WorkflowExecutionState` taken at checkpoint time
pub state_snapshot: String,
}
/// Workflow execution statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStatistics {
/// Number of tasks in the workflow definition
pub total_tasks: usize,
/// Tasks that finished successfully
pub completed_tasks: usize,
/// Tasks that failed after exhausting their retries
pub failed_tasks: usize,
/// Tasks currently executing
pub running_tasks: usize,
/// Tasks that have not started yet
pub pending_tasks: usize,
/// Wall-clock duration of the whole workflow in milliseconds, set when it finishes
pub total_duration_ms: u64,
/// Average per-task duration in milliseconds, when it could be computed
pub average_task_duration_ms: Option<u64>,
/// Completed tasks per minute of total workflow duration
pub throughput_tasks_per_minute: f64,
}
impl WorkflowStatistics {
    /// Fresh statistics for a workflow of `total_tasks` tasks.
    ///
    /// Every task starts out pending; all other counters and timings are zero
    /// and no average duration is known yet.
    pub fn new(total_tasks: usize) -> Self {
        let pending_tasks = total_tasks;
        Self {
            pending_tasks,
            total_tasks,
            running_tasks: 0,
            completed_tasks: 0,
            failed_tasks: 0,
            average_task_duration_ms: None,
            total_duration_ms: 0,
            throughput_tasks_per_minute: 0.0,
        }
    }
}
/// Main workflow execution engine.
///
/// Runs `WorkflowDefinition`s by enqueueing their tasks into a `TaskStorage`
/// backend and polling it for completion, respecting task dependencies.
pub struct BatchWorkflowEngine {
/// Backend that tasks are enqueued into and polled from
storage: Arc<dyn TaskStorage>,
/// Execution settings held by the engine
config: WorkflowConfig,
/// Latest known state of each workflow started by this engine, keyed by workflow id
execution_states: Arc<RwLock<HashMap<String, WorkflowExecutionState>>>,
}
impl BatchWorkflowEngine {
/// Create new workflow engine
pub fn new(storage: Arc<dyn TaskStorage>, config: WorkflowConfig) -> Self {
Self {
storage,
config,
execution_states: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Execute a workflow definition
pub async fn execute_workflow(
&self,
definition: WorkflowDefinition,
) -> Result<WorkflowExecutionState> {
let workflow_id = Uuid::new_v4().to_string();
info!("Starting workflow execution: {} ({})", definition.name, workflow_id);
// Build dependency graph
let mut graph = DependencyGraph::new();
for task_def in &definition.tasks {
graph.add_node(task_def.name.clone());
}
for task_def in &definition.tasks {
for dep in &task_def.dependencies {
graph.add_dependency(task_def.name.clone(), dep.clone());
}
}
// Validate dependency graph
let _sorted_tasks = graph.topological_sort()
.context("Failed to resolve task dependencies")?;
// Create initial execution state
let mut execution_state = WorkflowExecutionState {
workflow_id: workflow_id.clone(),
status: WorkflowStatus::Running,
started_at: chrono::Utc::now(),
completed_at: None,
task_states: HashMap::new(),
execution_graph: graph,
checkpoints: Vec::new(),
statistics: WorkflowStatistics::new(definition.tasks.len()),
};
// Initialize task states
for task_def in &definition.tasks {
execution_state.task_states.insert(
task_def.name.clone(),
TaskExecutionState {
task_id: Uuid::new_v4().to_string(),
status: TaskStatus::Pending,
retry_count: 0,
started_at: None,
completed_at: None,
duration_ms: None,
error: None,
}
);
}
// Store initial state
{
let mut states = self.execution_states.write().await;
states.insert(workflow_id.clone(), execution_state.clone());
}
// Execute workflow
let final_state = self.execute_workflow_internal(
workflow_id.clone(),
definition,
execution_state,
).await?;
info!("Workflow execution completed: {} ({})",
final_state.workflow_id, final_state.status);
Ok(final_state)
}
/// Internal workflow execution logic
async fn execute_workflow_internal(
&self,
workflow_id: String,
definition: WorkflowDefinition,
mut state: WorkflowExecutionState,
) -> Result<WorkflowExecutionState> {
let semaphore = Arc::new(Semaphore::new(self.config.max_parallel_tasks));
let mut completed_tasks = HashSet::new();
let mut failed_tasks = HashSet::new();
let start_time = Instant::now();
let mut last_checkpoint = Instant::now();
loop {
// Check for checkpoint
if last_checkpoint.elapsed().as_secs() >= self.config.checkpoint_interval_seconds {
self.create_checkpoint(&mut state, &completed_tasks, &failed_tasks).await?;
last_checkpoint = Instant::now();
}
// Get ready tasks
let ready_tasks = state.execution_graph.get_ready_tasks(&completed_tasks);
let pending_ready_tasks: Vec<_> = ready_tasks
.into_iter()
.filter(|task| {
!completed_tasks.contains(task) &&
!failed_tasks.contains(task) &&
state.task_states.get(task)
.map(|ts| ts.status == TaskStatus::Pending)
.unwrap_or(false)
})
.collect();
if pending_ready_tasks.is_empty() {
// Check if all tasks are done
let total_done = completed_tasks.len() + failed_tasks.len();
if total_done >= definition.tasks.len() {
break;
}
// Wait for running tasks
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
// Execute ready tasks in parallel
let tasks_futures = pending_ready_tasks.into_iter().map(|task_name| {
let task_def = definition.tasks.iter()
.find(|t| t.name == task_name)
.unwrap()
.clone();
let semaphore = semaphore.clone();
let storage = self.storage.clone();
let workflow_id = workflow_id.clone();
let config = self.config.clone();
async move {
let _permit = semaphore.acquire().await.unwrap();
self.execute_single_task(
workflow_id,
task_name.clone(),
task_def,
storage,
config,
).await
}
});
// Wait for batch completion
let results: Vec<_> = stream::iter(tasks_futures)
.buffer_unordered(self.config.max_parallel_tasks)
.collect()
.await;
// Process results
for (task_name, result) in results {
match result {
Ok(_) => {
completed_tasks.insert(task_name.clone());
if let Some(task_state) = state.task_states.get_mut(&task_name) {
task_state.status = TaskStatus::Completed;
task_state.completed_at = Some(chrono::Utc::now());
}
state.statistics.completed_tasks += 1;
state.statistics.pending_tasks -= 1;
}
Err(e) => {
error!("Task {} failed: {}", task_name, e);
failed_tasks.insert(task_name.clone());
if let Some(task_state) = state.task_states.get_mut(&task_name) {
task_state.status = TaskStatus::Failed;
task_state.completed_at = Some(chrono::Utc::now());
task_state.error = Some(e.to_string());
}
state.statistics.failed_tasks += 1;
state.statistics.pending_tasks -= 1;
if self.config.fail_fast {
state.status = WorkflowStatus::Failed;
return Ok(state);
}
}
}
}
// Update execution state
{
let mut states = self.execution_states.write().await;
states.insert(workflow_id.clone(), state.clone());
}
}
// Finalize workflow
state.completed_at = Some(chrono::Utc::now());
state.statistics.total_duration_ms = start_time.elapsed().as_millis() as u64;
if state.statistics.failed_tasks > 0 {
state.status = WorkflowStatus::Failed;
} else {
state.status = WorkflowStatus::Completed;
}
// Calculate final statistics
if state.statistics.completed_tasks > 0 {
let total_task_duration: u64 = state.task_states.values()
.filter_map(|ts| ts.duration_ms)
.sum();
state.statistics.average_task_duration_ms =
Some(total_task_duration / state.statistics.completed_tasks as u64);
}
let duration_minutes = state.statistics.total_duration_ms as f64 / 60000.0;
if duration_minutes > 0.0 {
state.statistics.throughput_tasks_per_minute =
state.statistics.completed_tasks as f64 / duration_minutes;
}
Ok(state)
}
/// Execute a single task
async fn execute_single_task(
&self,
workflow_id: String,
task_name: String,
task_def: WorkflowTaskDefinition,
storage: Arc<dyn TaskStorage>,
config: WorkflowConfig,
) -> (String, Result<()>) {
let task_start = Instant::now();
// Update task state to running
{
let mut states = self.execution_states.write().await;
if let Some(state) = states.get_mut(&workflow_id) {
if let Some(task_state) = state.task_states.get_mut(&task_name) {
task_state.status = TaskStatus::Running;
task_state.started_at = Some(chrono::Utc::now());
}
state.statistics.running_tasks += 1;
state.statistics.pending_tasks -= 1;
}
}
let result = self.execute_task_with_retry(
task_name.clone(),
task_def,
storage,
config,
).await;
// Update task duration
let duration_ms = task_start.elapsed().as_millis() as u64;
{
let mut states = self.execution_states.write().await;
if let Some(state) = states.get_mut(&workflow_id) {
if let Some(task_state) = state.task_states.get_mut(&task_name) {
task_state.duration_ms = Some(duration_ms);
}
state.statistics.running_tasks -= 1;
}
}
(task_name, result)
}
/// Execute task with retry logic
async fn execute_task_with_retry(
&self,
task_name: String,
task_def: WorkflowTaskDefinition,
storage: Arc<dyn TaskStorage>,
config: WorkflowConfig,
) -> Result<()> {
let max_retries = task_def.max_retries.unwrap_or(config.max_retries);
let mut last_error = None;
for retry in 0..=max_retries {
if retry > 0 {
info!("Retrying task {} (attempt {}/{})", task_name, retry + 1, max_retries + 1);
tokio::time::sleep(Duration::from_secs(config.retry_delay_seconds)).await;
}
// Create workflow task
let workflow_task = WorkflowTask {
id: Uuid::new_v4().to_string(),
name: task_name.clone(),
command: task_def.command.clone(),
args: task_def.args.clone(),
dependencies: task_def.dependencies.clone(),
status: TaskStatus::Pending,
created_at: chrono::Utc::now(),
started_at: None,
completed_at: None,
output: None,
error: None,
};
// Execute via storage backend
match storage.enqueue(workflow_task.clone(), 1).await {
Ok(_) => {
// Wait for completion
let timeout = Duration::from_secs(
task_def.timeout_seconds.unwrap_or(config.task_timeout_seconds)
);
match self.wait_for_task_completion(&workflow_task.id, timeout, storage.clone()).await {
Ok(_) => return Ok(()),
Err(e) => {
last_error = Some(e);
continue;
}
}
}
Err(e) => {
last_error = Some(e.into());
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Task execution failed after retries")))
}
/// Wait for task completion with timeout
async fn wait_for_task_completion(
&self,
task_id: &str,
timeout: Duration,
storage: Arc<dyn TaskStorage>,
) -> Result<WorkflowTask> {
let start = Instant::now();
let poll_interval = Duration::from_secs(1);
loop {
if start.elapsed() > timeout {
return Err(anyhow::anyhow!("Task execution timeout"));
}
match storage.get_task(task_id).await {
Ok(Some(task)) => {
match task.status {
TaskStatus::Completed => return Ok(task),
TaskStatus::Failed | TaskStatus::Cancelled => {
return Err(anyhow::anyhow!(
"Task failed: {}",
task.error.unwrap_or_else(|| "Unknown error".to_string())
));
}
_ => {
tokio::time::sleep(poll_interval).await;
}
}
}
Ok(None) => {
return Err(anyhow::anyhow!("Task not found"));
}
Err(e) => {
return Err(anyhow::anyhow!("Storage error: {}", e));
}
}
}
}
/// Create workflow checkpoint
async fn create_checkpoint(
&self,
state: &mut WorkflowExecutionState,
completed_tasks: &HashSet<String>,
failed_tasks: &HashSet<String>,
) -> Result<()> {
let checkpoint = WorkflowCheckpoint {
timestamp: chrono::Utc::now(),
completed_tasks: completed_tasks.clone(),
failed_tasks: failed_tasks.clone(),
state_snapshot: serde_json::to_string(state)?,
};
state.checkpoints.push(checkpoint);
// Keep only last 10 checkpoints
if state.checkpoints.len() > 10 {
state.checkpoints.drain(0..state.checkpoints.len() - 10);
}
debug!("Created workflow checkpoint for {}", state.workflow_id);
Ok(())
}
/// Get workflow execution state
pub async fn get_workflow_state(&self, workflow_id: &str) -> Option<WorkflowExecutionState> {
let states = self.execution_states.read().await;
states.get(workflow_id).cloned()
}
/// List all workflow states
pub async fn list_workflows(&self) -> Vec<WorkflowExecutionState> {
let states = self.execution_states.read().await;
states.values().cloned().collect()
}
/// Cancel a running workflow
pub async fn cancel_workflow(&self, workflow_id: &str) -> Result<()> {
let mut states = self.execution_states.write().await;
if let Some(state) = states.get_mut(workflow_id) {
if state.status == WorkflowStatus::Running {
state.status = WorkflowStatus::Cancelled;
state.completed_at = Some(chrono::Utc::now());
info!("Cancelled workflow: {}", workflow_id);
}
}
Ok(())
}
}
/// Parse a workflow definition from its JSON form (as exported from KCL).
///
/// # Errors
///
/// Returns an error carrying the context "Failed to parse KCL workflow
/// definition" when `json_content` is not valid JSON for a
/// [`WorkflowDefinition`].
pub fn parse_kcl_workflow(json_content: &str) -> Result<WorkflowDefinition> {
    let parsed = serde_json::from_str::<WorkflowDefinition>(json_content);
    parsed.context("Failed to parse KCL workflow definition")
}
/// Validate workflow definition
pub fn validate_workflow_definition(definition: &WorkflowDefinition) -> Result<()> {
// Check for duplicate task names
let mut task_names = HashSet::new();
for task in &definition.tasks {
if !task_names.insert(&task.name) {
return Err(anyhow::anyhow!("Duplicate task name: {}", task.name));
}
}
// Check dependency references
for task in &definition.tasks {
for dep in &task.dependencies {
if !task_names.contains(dep) {
return Err(anyhow::anyhow!(
"Task '{}' has invalid dependency: '{}'",
task.name, dep
));
}
}
}
Ok(())
}