398 lines
10 KiB
Rust
Raw Normal View History

use std::env;
use std::path::Path;
use platform_config::ConfigLoader;
/// AI Service configuration
use serde::{Deserialize, Serialize};
/// Main AI Service configuration.
///
/// Top-level configuration document with one section per subsystem. The
/// serialized keys are the field names (`server`, `rag`, `mcp`, `dag`).
/// Every section is `#[serde(default)]`, so a section missing from the source
/// document is filled from that section's `Default` impl. As a result, an empty
/// document deserializes to the same value as `AiServiceConfig::default()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AiServiceConfig {
/// Server configuration (`server` section).
#[serde(default)]
pub server: ServerConfig,
/// RAG integration configuration (`rag` section).
#[serde(default)]
pub rag: RagIntegrationConfig,
/// MCP integration configuration (`mcp` section).
#[serde(default)]
pub mcp: McpIntegrationConfig,
/// DAG execution configuration (`dag` section).
#[serde(default)]
pub dag: DagConfig,
}
/// Server configuration.
///
/// Each field missing from the source document falls back to its
/// `default_*` function. `ServerConfig::default()` is built from the same
/// functions, so both paths yield identical values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Server bind address. Defaults to `127.0.0.1`.
#[serde(default = "default_host")]
pub host: String,
/// Server port. Defaults to `8082`.
#[serde(default = "default_server_port")]
pub port: u16,
/// Number of worker threads. Defaults to `4`.
#[serde(default = "default_workers")]
pub workers: usize,
/// TCP keep-alive timeout (seconds). Defaults to `75`.
#[serde(default = "default_keep_alive")]
pub keep_alive: u64,
/// Request timeout (milliseconds). Defaults to `30000` (30 s).
#[serde(default = "default_request_timeout")]
pub request_timeout: u64,
}
/// RAG integration configuration.
///
/// Missing fields fall back to their `default_*` functions (or `false` for
/// `enabled`). These match what `RagIntegrationConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagIntegrationConfig {
/// Enable RAG integration. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// RAG service URL. Defaults to `http://localhost:8083`.
#[serde(default = "default_rag_url")]
pub rag_service_url: String,
/// Request timeout (milliseconds). Defaults to `30000` (30 s).
#[serde(default = "default_rag_timeout")]
pub timeout: u64,
/// Max retries for failed requests. Defaults to `3`.
#[serde(default = "default_max_retries")]
pub max_retries: u32,
/// Enable response caching. Defaults to `true`.
#[serde(default = "default_true")]
pub cache_enabled: bool,
}
/// MCP integration configuration.
///
/// Missing fields fall back to their `default_*` functions (or `false` for
/// `enabled`). These match what `McpIntegrationConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpIntegrationConfig {
/// Enable MCP integration. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// MCP service URL. Defaults to `http://localhost:8084`.
#[serde(default = "default_mcp_url")]
pub mcp_service_url: String,
/// Request timeout (milliseconds). Defaults to `30000` (30 s).
#[serde(default = "default_mcp_timeout")]
pub timeout: u64,
/// Max retries for failed requests. Defaults to `3`.
#[serde(default = "default_max_retries")]
pub max_retries: u32,
/// MCP protocol version string. Defaults to `"1.0"`.
#[serde(default = "default_protocol_version")]
pub protocol_version: String,
}
/// DAG execution configuration.
///
/// Missing fields fall back to their `default_*` functions. These match what
/// `DagConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagConfig {
/// Maximum concurrent tasks. Defaults to `10`.
#[serde(default = "default_max_concurrent_tasks")]
pub max_concurrent_tasks: usize,
/// Task timeout (milliseconds). Defaults to `600000` (10 min).
#[serde(default = "default_task_timeout")]
pub task_timeout: u64,
/// Number of retry attempts. Defaults to `3`.
#[serde(default = "default_dag_retry_attempts")]
pub retry_attempts: u32,
/// Delay between retries (milliseconds). Defaults to `1000` (1 s).
#[serde(default = "default_retry_delay")]
pub retry_delay: u64,
/// Task queue size. Defaults to `1000`.
#[serde(default = "default_queue_size")]
pub queue_size: usize,
}
// Default value functions.
//
// serde calls these by name from the `#[serde(default = "...")]` attributes
// when a field is absent. The hand-written `Default` impls reuse them, so a
// missing field and a default-constructed struct always agree.

/// Default server bind address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default server port.
const fn default_server_port() -> u16 {
    8082
}

/// Default number of server worker threads.
const fn default_workers() -> usize {
    4
}

/// Default TCP keep-alive timeout, in seconds.
const fn default_keep_alive() -> u64 {
    75
}

/// Default server request timeout, in milliseconds.
const fn default_request_timeout() -> u64 {
    30_000
}

/// Default RAG service URL.
fn default_rag_url() -> String {
    String::from("http://localhost:8083")
}

/// Default RAG request timeout, in milliseconds.
const fn default_rag_timeout() -> u64 {
    30_000
}

/// Default MCP service URL.
fn default_mcp_url() -> String {
    String::from("http://localhost:8084")
}

/// Default MCP request timeout, in milliseconds.
const fn default_mcp_timeout() -> u64 {
    30_000
}

/// Default retry count shared by the RAG and MCP integrations.
const fn default_max_retries() -> u32 {
    3
}

/// Helper for boolean fields whose serde default is `true`.
const fn default_true() -> bool {
    true
}

/// Default MCP protocol version string.
fn default_protocol_version() -> String {
    String::from("1.0")
}

/// Default cap on concurrently running DAG tasks.
const fn default_max_concurrent_tasks() -> usize {
    10
}

/// Default DAG task timeout, in milliseconds.
const fn default_task_timeout() -> u64 {
    600_000
}

/// Default number of DAG retry attempts.
const fn default_dag_retry_attempts() -> u32 {
    3
}

/// Default delay between DAG retries, in milliseconds.
const fn default_retry_delay() -> u64 {
    1_000
}

/// Default DAG task queue size.
const fn default_queue_size() -> usize {
    1_000
}
impl Default for ServerConfig {
    /// Builds a `ServerConfig` from the same `default_*` helpers that serde
    /// uses for missing fields.
    fn default() -> Self {
        let host = default_host();
        let port = default_server_port();
        let workers = default_workers();
        let keep_alive = default_keep_alive();
        let request_timeout = default_request_timeout();
        Self {
            host,
            port,
            workers,
            keep_alive,
            request_timeout,
        }
    }
}
impl Default for RagIntegrationConfig {
    /// RAG integration starts disabled. Every other field comes from the
    /// `default_*` helper that serde uses for that field.
    fn default() -> Self {
        let rag_service_url = default_rag_url();
        let timeout = default_rag_timeout();
        let max_retries = default_max_retries();
        let cache_enabled = default_true();
        Self {
            enabled: false,
            rag_service_url,
            timeout,
            max_retries,
            cache_enabled,
        }
    }
}
impl Default for McpIntegrationConfig {
    /// MCP integration starts disabled. Every other field comes from the
    /// `default_*` helper that serde uses for that field.
    fn default() -> Self {
        let mcp_service_url = default_mcp_url();
        let timeout = default_mcp_timeout();
        let max_retries = default_max_retries();
        let protocol_version = default_protocol_version();
        Self {
            enabled: false,
            mcp_service_url,
            timeout,
            max_retries,
            protocol_version,
        }
    }
}
impl Default for DagConfig {
    /// Builds a `DagConfig` from the same `default_*` helpers that serde uses
    /// for missing fields.
    fn default() -> Self {
        let max_concurrent_tasks = default_max_concurrent_tasks();
        let task_timeout = default_task_timeout();
        let retry_attempts = default_dag_retry_attempts();
        let retry_delay = default_retry_delay();
        let queue_size = default_queue_size();
        Self {
            max_concurrent_tasks,
            task_timeout,
            retry_attempts,
            retry_delay,
            queue_size,
        }
    }
}
impl ConfigLoader for AiServiceConfig {
fn service_name() -> &'static str {
"ai-service"
}
fn load_from_hierarchy() -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>>
{
let service = Self::service_name();
if let Some(path) = platform_config::resolve_config_path(service) {
return Self::from_path(&path);
}
// Fallback to defaults
Ok(Self::default())
}
fn apply_env_overrides(
&mut self,
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
Self::apply_env_overrides_internal(self);
Ok(())
}
fn from_path<P: AsRef<Path>>(
path: P,
) -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let path = path.as_ref();
let json_value = platform_config::format::load_config(path).map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
err
})?;
serde_json::from_value(json_value).map_err(|e| {
let err_msg = format!(
"Failed to deserialize AI service config from {:?}: {}",
path, e
);
Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
err_msg,
)) as Box<dyn std::error::Error + Send + Sync>
})
}
}
impl AiServiceConfig {
    /// Load configuration from hierarchical sources with mode support, then
    /// apply environment variable overrides.
    ///
    /// Priority order for the base configuration (as documented; the lookup
    /// itself is delegated to `platform_config::resolve_config_path` through the
    /// `ConfigLoader` impl):
    /// 1. AI_SERVICE_CONFIG environment variable (explicit path)
    /// 2. AI_SERVICE_MODE environment variable (mode-specific file)
    /// 3. Default configuration
    ///
    /// After loading, the `AI_SERVICE_*` environment overrides are applied on
    /// top. Applying them a second time (e.g. via
    /// `ConfigLoader::apply_env_overrides`) is harmless, because each override
    /// just re-assigns the same value.
    ///
    /// # Errors
    ///
    /// Returns an error when a configuration file was located but could not be
    /// loaded or deserialized. The message includes the underlying cause.
    pub fn load_from_hierarchy() -> Result<Self, Box<dyn std::error::Error>> {
        let mut config = <Self as ConfigLoader>::load_from_hierarchy().map_err(|e| {
            // Keep the underlying cause; previously it was discarded, which
            // made load failures impossible to diagnose.
            Box::new(std::io::Error::other(format!(
                "Failed to load AI service config: {e}"
            ))) as Box<dyn std::error::Error>
        })?;
        Self::apply_env_overrides_internal(&mut config);
        Ok(config)
    }

    /// Internal: apply environment variable overrides (mutable reference).
    ///
    /// Overrides take precedence over loaded config values.
    /// Pattern: AI_SERVICE_{SECTION}_{KEY}. The service URLs use the shorter
    /// `AI_SERVICE_RAG_URL` / `AI_SERVICE_MCP_URL`.
    ///
    /// Overrides are best-effort. A variable that is unset, is not valid
    /// Unicode, or fails to parse leaves the current value untouched. Numeric
    /// values are trimmed before parsing. Booleans accept `true`/`false`,
    /// `1`/`0`, `yes`/`no` and `on`/`off` (case-insensitive). String values are
    /// used verbatim.
    fn apply_env_overrides_internal(config: &mut Self) {
        // Server overrides
        let server = &mut config.server;
        Self::override_string("AI_SERVICE_SERVER_HOST", &mut server.host);
        Self::override_parsed("AI_SERVICE_SERVER_PORT", &mut server.port);
        Self::override_parsed("AI_SERVICE_SERVER_WORKERS", &mut server.workers);
        Self::override_parsed("AI_SERVICE_SERVER_KEEP_ALIVE", &mut server.keep_alive);
        Self::override_parsed(
            "AI_SERVICE_SERVER_REQUEST_TIMEOUT",
            &mut server.request_timeout,
        );

        // RAG integration overrides
        let rag = &mut config.rag;
        Self::override_bool("AI_SERVICE_RAG_ENABLED", &mut rag.enabled);
        Self::override_string("AI_SERVICE_RAG_URL", &mut rag.rag_service_url);
        Self::override_parsed("AI_SERVICE_RAG_TIMEOUT", &mut rag.timeout);
        Self::override_parsed("AI_SERVICE_RAG_MAX_RETRIES", &mut rag.max_retries);
        Self::override_bool("AI_SERVICE_RAG_CACHE_ENABLED", &mut rag.cache_enabled);

        // MCP integration overrides
        let mcp = &mut config.mcp;
        Self::override_bool("AI_SERVICE_MCP_ENABLED", &mut mcp.enabled);
        Self::override_string("AI_SERVICE_MCP_URL", &mut mcp.mcp_service_url);
        Self::override_parsed("AI_SERVICE_MCP_TIMEOUT", &mut mcp.timeout);
        Self::override_parsed("AI_SERVICE_MCP_MAX_RETRIES", &mut mcp.max_retries);
        Self::override_string(
            "AI_SERVICE_MCP_PROTOCOL_VERSION",
            &mut mcp.protocol_version,
        );

        // DAG overrides
        let dag = &mut config.dag;
        Self::override_parsed(
            "AI_SERVICE_DAG_MAX_CONCURRENT_TASKS",
            &mut dag.max_concurrent_tasks,
        );
        Self::override_parsed("AI_SERVICE_DAG_TASK_TIMEOUT", &mut dag.task_timeout);
        Self::override_parsed("AI_SERVICE_DAG_RETRY_ATTEMPTS", &mut dag.retry_attempts);
        Self::override_parsed("AI_SERVICE_DAG_RETRY_DELAY", &mut dag.retry_delay);
        Self::override_parsed("AI_SERVICE_DAG_QUEUE_SIZE", &mut dag.queue_size);
    }

    /// Overwrites `slot` with the value of env var `key`, verbatim, if it is set.
    fn override_string(key: &str, slot: &mut String) {
        if let Ok(value) = env::var(key) {
            *slot = value;
        }
    }

    /// Overwrites `slot` with env var `key` parsed as `T`, after trimming
    /// surrounding whitespace. Leaves `slot` untouched if the variable is
    /// unset or does not parse.
    fn override_parsed<T: std::str::FromStr>(key: &str, slot: &mut T) {
        if let Some(value) = env::var(key)
            .ok()
            .and_then(|raw| raw.trim().parse::<T>().ok())
        {
            *slot = value;
        }
    }

    /// Overwrites `slot` with env var `key` interpreted as a boolean flag.
    ///
    /// Matching is case-insensitive and tolerant of surrounding whitespace.
    /// Unrecognized values leave `slot` untouched rather than guessing.
    fn override_bool(key: &str, slot: &mut bool) {
        let Ok(raw) = env::var(key) else {
            return;
        };
        match raw.trim().to_ascii_lowercase().as_str() {
            "true" | "1" | "yes" | "on" => *slot = true,
            "false" | "0" | "no" | "off" => *slot = false,
            _ => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = AiServiceConfig::default();
        assert_eq!(config.server.port, 8082);
        assert_eq!(config.server.workers, 4);
        assert!(!config.rag.enabled);
        assert!(!config.mcp.enabled);
        assert_eq!(config.dag.max_concurrent_tasks, 10);
    }

    #[test]
    fn test_server_config_defaults() {
        let server = ServerConfig::default();
        assert_eq!(server.host, "127.0.0.1");
        assert_eq!(server.port, 8082);
        assert_eq!(server.workers, 4);
        assert_eq!(server.keep_alive, 75);
        assert_eq!(server.request_timeout, 30000);
    }

    #[test]
    fn test_rag_config_defaults() {
        let rag = RagIntegrationConfig::default();
        assert!(!rag.enabled);
        assert_eq!(rag.rag_service_url, "http://localhost:8083");
        assert_eq!(rag.timeout, 30000);
        assert_eq!(rag.max_retries, 3);
        assert!(rag.cache_enabled);
    }

    #[test]
    fn test_mcp_config_defaults() {
        let mcp = McpIntegrationConfig::default();
        assert!(!mcp.enabled);
        assert_eq!(mcp.mcp_service_url, "http://localhost:8084");
        assert_eq!(mcp.timeout, 30000);
        assert_eq!(mcp.max_retries, 3);
        assert_eq!(mcp.protocol_version, "1.0");
    }

    #[test]
    fn test_dag_config_defaults() {
        let dag = DagConfig::default();
        assert_eq!(dag.max_concurrent_tasks, 10);
        assert_eq!(dag.task_timeout, 600000);
        assert_eq!(dag.retry_attempts, 3);
        assert_eq!(dag.retry_delay, 1000);
        assert_eq!(dag.queue_size, 1000);
    }

    /// An empty document must deserialize to the same values as `Default`,
    /// because every field carries a serde default.
    #[test]
    fn test_empty_document_uses_defaults() {
        let config: AiServiceConfig =
            serde_json::from_str("{}").expect("empty object is a valid config");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 8082);
        assert!(!config.rag.enabled);
        assert!(config.rag.cache_enabled);
        assert_eq!(config.mcp.protocol_version, "1.0");
        assert_eq!(config.dag.queue_size, 1000);
    }

    /// Supplying only some keys in a section must keep defaults for the rest.
    #[test]
    fn test_partial_section_keeps_other_defaults() {
        let config: AiServiceConfig =
            serde_json::from_str(r#"{"server": {"port": 9000}, "rag": {"enabled": true}}"#)
                .expect("partial config is valid");
        assert_eq!(config.server.port, 9000);
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.workers, 4);
        assert!(config.rag.enabled);
        assert_eq!(config.rag.rag_service_url, "http://localhost:8083");
        assert!(!config.mcp.enabled);
        assert_eq!(config.dag.max_concurrent_tasks, 10);
    }
}