398 lines
10 KiB
Rust
398 lines
10 KiB
Rust
|
|
use std::env;
|
||
|
|
use std::path::Path;
|
||
|
|
|
||
|
|
use platform_config::ConfigLoader;
|
||
|
|
/// AI Service configuration
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
|
||
|
|
/// Main AI Service configuration
|
||
|
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||
|
|
pub struct AiServiceConfig {
|
||
|
|
/// Server configuration
|
||
|
|
#[serde(default)]
|
||
|
|
pub server: ServerConfig,
|
||
|
|
|
||
|
|
/// RAG integration configuration
|
||
|
|
#[serde(default)]
|
||
|
|
pub rag: RagIntegrationConfig,
|
||
|
|
|
||
|
|
/// MCP integration configuration
|
||
|
|
#[serde(default)]
|
||
|
|
pub mcp: McpIntegrationConfig,
|
||
|
|
|
||
|
|
/// DAG execution configuration
|
||
|
|
#[serde(default)]
|
||
|
|
pub dag: DagConfig,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Server configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct ServerConfig {
|
||
|
|
/// Server bind address
|
||
|
|
#[serde(default = "default_host")]
|
||
|
|
pub host: String,
|
||
|
|
|
||
|
|
/// Server port
|
||
|
|
#[serde(default = "default_server_port")]
|
||
|
|
pub port: u16,
|
||
|
|
|
||
|
|
/// Number of worker threads
|
||
|
|
#[serde(default = "default_workers")]
|
||
|
|
pub workers: usize,
|
||
|
|
|
||
|
|
/// TCP keep-alive timeout (seconds)
|
||
|
|
#[serde(default = "default_keep_alive")]
|
||
|
|
pub keep_alive: u64,
|
||
|
|
|
||
|
|
/// Request timeout (milliseconds)
|
||
|
|
#[serde(default = "default_request_timeout")]
|
||
|
|
pub request_timeout: u64,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// RAG integration configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct RagIntegrationConfig {
|
||
|
|
/// Enable RAG integration
|
||
|
|
#[serde(default)]
|
||
|
|
pub enabled: bool,
|
||
|
|
|
||
|
|
/// RAG service URL
|
||
|
|
#[serde(default = "default_rag_url")]
|
||
|
|
pub rag_service_url: String,
|
||
|
|
|
||
|
|
/// Request timeout (milliseconds)
|
||
|
|
#[serde(default = "default_rag_timeout")]
|
||
|
|
pub timeout: u64,
|
||
|
|
|
||
|
|
/// Max retries for failed requests
|
||
|
|
#[serde(default = "default_max_retries")]
|
||
|
|
pub max_retries: u32,
|
||
|
|
|
||
|
|
/// Enable response caching
|
||
|
|
#[serde(default = "default_true")]
|
||
|
|
pub cache_enabled: bool,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// MCP integration configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct McpIntegrationConfig {
|
||
|
|
/// Enable MCP integration
|
||
|
|
#[serde(default)]
|
||
|
|
pub enabled: bool,
|
||
|
|
|
||
|
|
/// MCP service URL
|
||
|
|
#[serde(default = "default_mcp_url")]
|
||
|
|
pub mcp_service_url: String,
|
||
|
|
|
||
|
|
/// Request timeout (milliseconds)
|
||
|
|
#[serde(default = "default_mcp_timeout")]
|
||
|
|
pub timeout: u64,
|
||
|
|
|
||
|
|
/// Max retries for failed requests
|
||
|
|
#[serde(default = "default_max_retries")]
|
||
|
|
pub max_retries: u32,
|
||
|
|
|
||
|
|
/// MCP protocol version
|
||
|
|
#[serde(default = "default_protocol_version")]
|
||
|
|
pub protocol_version: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// DAG execution configuration
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct DagConfig {
|
||
|
|
/// Maximum concurrent tasks
|
||
|
|
#[serde(default = "default_max_concurrent_tasks")]
|
||
|
|
pub max_concurrent_tasks: usize,
|
||
|
|
|
||
|
|
/// Task timeout (milliseconds)
|
||
|
|
#[serde(default = "default_task_timeout")]
|
||
|
|
pub task_timeout: u64,
|
||
|
|
|
||
|
|
/// Number of retry attempts
|
||
|
|
#[serde(default = "default_dag_retry_attempts")]
|
||
|
|
pub retry_attempts: u32,
|
||
|
|
|
||
|
|
/// Delay between retries (milliseconds)
|
||
|
|
#[serde(default = "default_retry_delay")]
|
||
|
|
pub retry_delay: u64,
|
||
|
|
|
||
|
|
/// Task queue size
|
||
|
|
#[serde(default = "default_queue_size")]
|
||
|
|
pub queue_size: usize,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Default value functions
|
||
|
|
/// Serde default for the server bind address: `127.0.0.1`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
|
||
|
|
|
||
|
|
/// Serde default for the server port.
fn default_server_port() -> u16 {
    const PORT: u16 = 8082;
    PORT
}
|
||
|
|
|
||
|
|
/// Serde default for the number of server worker threads.
fn default_workers() -> usize {
    const WORKERS: usize = 4;
    WORKERS
}
|
||
|
|
|
||
|
|
/// Serde default for the TCP keep-alive timeout (the field is in seconds).
fn default_keep_alive() -> u64 {
    const KEEP_ALIVE_SECS: u64 = 75;
    KEEP_ALIVE_SECS
}
|
||
|
|
|
||
|
|
/// Serde default for the server request timeout (the field is in milliseconds).
fn default_request_timeout() -> u64 {
    const REQUEST_TIMEOUT_MS: u64 = 30_000;
    REQUEST_TIMEOUT_MS
}
|
||
|
|
|
||
|
|
/// Serde default for the RAG service URL.
fn default_rag_url() -> String {
    String::from("http://localhost:8083")
}
|
||
|
|
|
||
|
|
/// Serde default for the RAG request timeout (the field is in milliseconds).
fn default_rag_timeout() -> u64 {
    const RAG_TIMEOUT_MS: u64 = 30_000;
    RAG_TIMEOUT_MS
}
|
||
|
|
|
||
|
|
/// Serde default for the MCP service URL.
fn default_mcp_url() -> String {
    String::from("http://localhost:8084")
}
|
||
|
|
|
||
|
|
/// Serde default for the MCP request timeout (the field is in milliseconds).
fn default_mcp_timeout() -> u64 {
    const MCP_TIMEOUT_MS: u64 = 30_000;
    MCP_TIMEOUT_MS
}
|
||
|
|
|
||
|
|
/// Serde default for the retry limit shared by the RAG and MCP integrations.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}
|
||
|
|
|
||
|
|
/// Serde default for boolean fields that should be on unless disabled.
///
/// Needed because a plain `#[serde(default)]` on a `bool` yields `false`.
fn default_true() -> bool {
    !false
}
|
||
|
|
|
||
|
|
/// Serde default for the MCP protocol version string.
fn default_protocol_version() -> String {
    String::from("1.0")
}
|
||
|
|
|
||
|
|
/// Serde default for the maximum number of concurrently running DAG tasks.
fn default_max_concurrent_tasks() -> usize {
    const MAX_CONCURRENT_TASKS: usize = 10;
    MAX_CONCURRENT_TASKS
}
|
||
|
|
|
||
|
|
/// Serde default for the DAG task timeout (the field is in milliseconds).
fn default_task_timeout() -> u64 {
    const TASK_TIMEOUT_MS: u64 = 600_000;
    TASK_TIMEOUT_MS
}
|
||
|
|
|
||
|
|
/// Serde default for the number of DAG retry attempts.
fn default_dag_retry_attempts() -> u32 {
    const RETRY_ATTEMPTS: u32 = 3;
    RETRY_ATTEMPTS
}
|
||
|
|
|
||
|
|
/// Serde default for the delay between DAG retries (the field is in milliseconds).
fn default_retry_delay() -> u64 {
    const RETRY_DELAY_MS: u64 = 1_000;
    RETRY_DELAY_MS
}
|
||
|
|
|
||
|
|
/// Serde default for the DAG task queue size.
fn default_queue_size() -> usize {
    const QUEUE_SIZE: usize = 1_000;
    QUEUE_SIZE
}
|
||
|
|
|
||
|
|
impl Default for ServerConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
host: default_host(),
|
||
|
|
port: default_server_port(),
|
||
|
|
workers: default_workers(),
|
||
|
|
keep_alive: default_keep_alive(),
|
||
|
|
request_timeout: default_request_timeout(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for RagIntegrationConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
enabled: false,
|
||
|
|
rag_service_url: default_rag_url(),
|
||
|
|
timeout: default_rag_timeout(),
|
||
|
|
max_retries: default_max_retries(),
|
||
|
|
cache_enabled: default_true(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for McpIntegrationConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
enabled: false,
|
||
|
|
mcp_service_url: default_mcp_url(),
|
||
|
|
timeout: default_mcp_timeout(),
|
||
|
|
max_retries: default_max_retries(),
|
||
|
|
protocol_version: default_protocol_version(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Default for DagConfig {
|
||
|
|
fn default() -> Self {
|
||
|
|
Self {
|
||
|
|
max_concurrent_tasks: default_max_concurrent_tasks(),
|
||
|
|
task_timeout: default_task_timeout(),
|
||
|
|
retry_attempts: default_dag_retry_attempts(),
|
||
|
|
retry_delay: default_retry_delay(),
|
||
|
|
queue_size: default_queue_size(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl ConfigLoader for AiServiceConfig {
|
||
|
|
fn service_name() -> &'static str {
|
||
|
|
"ai-service"
|
||
|
|
}
|
||
|
|
|
||
|
|
fn load_from_hierarchy() -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>>
|
||
|
|
{
|
||
|
|
let service = Self::service_name();
|
||
|
|
|
||
|
|
if let Some(path) = platform_config::resolve_config_path(service) {
|
||
|
|
return Self::from_path(&path);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Fallback to defaults
|
||
|
|
Ok(Self::default())
|
||
|
|
}
|
||
|
|
|
||
|
|
fn apply_env_overrides(
|
||
|
|
&mut self,
|
||
|
|
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
|
|
Self::apply_env_overrides_internal(self);
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
fn from_path<P: AsRef<Path>>(
|
||
|
|
path: P,
|
||
|
|
) -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||
|
|
let path = path.as_ref();
|
||
|
|
let json_value = platform_config::format::load_config(path).map_err(|e| {
|
||
|
|
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
|
||
|
|
err
|
||
|
|
})?;
|
||
|
|
|
||
|
|
serde_json::from_value(json_value).map_err(|e| {
|
||
|
|
let err_msg = format!(
|
||
|
|
"Failed to deserialize AI service config from {:?}: {}",
|
||
|
|
path, e
|
||
|
|
);
|
||
|
|
Box::new(std::io::Error::new(
|
||
|
|
std::io::ErrorKind::InvalidData,
|
||
|
|
err_msg,
|
||
|
|
)) as Box<dyn std::error::Error + Send + Sync>
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl AiServiceConfig {
|
||
|
|
/// Load configuration from hierarchical sources with mode support
|
||
|
|
///
|
||
|
|
/// Priority order:
|
||
|
|
/// 1. AI_SERVICE_CONFIG environment variable (explicit path)
|
||
|
|
/// 2. AI_SERVICE_MODE environment variable (mode-specific file)
|
||
|
|
/// 3. Default configuration
|
||
|
|
///
|
||
|
|
/// After loading, applies environment variable overrides.
|
||
|
|
pub fn load_from_hierarchy() -> Result<Self, Box<dyn std::error::Error>> {
|
||
|
|
<Self as ConfigLoader>::load_from_hierarchy().map_err(|_e| {
|
||
|
|
Box::new(std::io::Error::other("Failed to load AI service config"))
|
||
|
|
as Box<dyn std::error::Error>
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Internal: Apply environment variable overrides (mutable reference)
|
||
|
|
///
|
||
|
|
/// Overrides take precedence over loaded config values.
|
||
|
|
/// Pattern: AI_SERVICE_{SECTION}_{KEY}
|
||
|
|
fn apply_env_overrides_internal(config: &mut Self) {
|
||
|
|
// Server overrides
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_SERVER_HOST") {
|
||
|
|
config.server.host = val;
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_SERVER_PORT") {
|
||
|
|
if let Ok(port) = val.parse() {
|
||
|
|
config.server.port = port;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_SERVER_WORKERS") {
|
||
|
|
if let Ok(workers) = val.parse() {
|
||
|
|
config.server.workers = workers;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// RAG integration overrides
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_RAG_ENABLED") {
|
||
|
|
config.rag.enabled = val.parse().unwrap_or(config.rag.enabled);
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_RAG_URL") {
|
||
|
|
config.rag.rag_service_url = val;
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_RAG_TIMEOUT") {
|
||
|
|
if let Ok(timeout) = val.parse() {
|
||
|
|
config.rag.timeout = timeout;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// MCP integration overrides
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_MCP_ENABLED") {
|
||
|
|
config.mcp.enabled = val.parse().unwrap_or(config.mcp.enabled);
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_MCP_URL") {
|
||
|
|
config.mcp.mcp_service_url = val;
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_MCP_TIMEOUT") {
|
||
|
|
if let Ok(timeout) = val.parse() {
|
||
|
|
config.mcp.timeout = timeout;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// DAG overrides
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_DAG_MAX_CONCURRENT_TASKS") {
|
||
|
|
if let Ok(tasks) = val.parse() {
|
||
|
|
config.dag.max_concurrent_tasks = tasks;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_DAG_TASK_TIMEOUT") {
|
||
|
|
if let Ok(timeout) = val.parse() {
|
||
|
|
config.dag.task_timeout = timeout;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if let Ok(val) = env::var("AI_SERVICE_DAG_RETRY_ATTEMPTS") {
|
||
|
|
if let Ok(retries) = val.parse() {
|
||
|
|
config.dag.retry_attempts = retries;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The aggregate default exposes the expected per-section defaults.
    #[test]
    fn test_default_config() {
        let AiServiceConfig {
            server,
            rag,
            mcp,
            dag,
        } = AiServiceConfig::default();

        assert_eq!(server.port, 8082);
        assert_eq!(server.workers, 4);
        assert!(!rag.enabled);
        assert!(!mcp.enabled);
        assert_eq!(dag.max_concurrent_tasks, 10);
    }

    /// `ServerConfig::default()` yields the loopback host, port 8082 and 4 workers.
    #[test]
    fn test_server_config_defaults() {
        let defaults = ServerConfig::default();

        assert_eq!(defaults.host, "127.0.0.1");
        assert_eq!(defaults.port, 8082);
        assert_eq!(defaults.workers, 4);
    }

    /// `DagConfig::default()` yields the expected concurrency, timeout and retries.
    #[test]
    fn test_dag_config_defaults() {
        let defaults = DagConfig::default();

        assert_eq!(defaults.max_concurrent_tasks, 10);
        assert_eq!(defaults.task_timeout, 600000);
        assert_eq!(defaults.retry_attempts, 3);
    }
}
|