// vapora-llm-router: Configuration module // Load and parse LLM router configuration from TOML use std::collections::HashMap; use std::path::Path; use serde::{Deserialize, Serialize}; use thiserror::Error; #[derive(Debug, Error)] pub enum ConfigError { #[error("Failed to read config file: {0}")] ReadError(#[from] std::io::Error), #[error("Failed to parse TOML: {0}")] ParseError(#[from] toml::de::Error), #[error("Invalid configuration: {0}")] ValidationError(String), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LLMRouterConfig { pub routing: RoutingConfig, pub providers: HashMap, #[serde(default)] pub routing_rules: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoutingConfig { pub default_provider: String, #[serde(default = "default_true")] pub cost_tracking_enabled: bool, #[serde(default = "default_true")] pub fallback_enabled: bool, } fn default_true() -> bool { true } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProviderConfig { #[serde(default = "default_true")] pub enabled: bool, pub api_key: Option, pub url: Option, pub model: String, #[serde(default = "default_max_tokens")] pub max_tokens: usize, #[serde(default = "default_temperature")] pub temperature: f32, #[serde(default)] pub cost_per_1m_input: f64, #[serde(default)] pub cost_per_1m_output: f64, } fn default_max_tokens() -> usize { 4096 } fn default_temperature() -> f32 { 0.7 } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoutingRule { pub name: String, pub condition: HashMap, pub provider: String, pub model_override: Option, } impl LLMRouterConfig { /// Load configuration from TOML file pub fn load>(path: P) -> Result { let content = std::fs::read_to_string(path)?; let mut config: Self = toml::from_str(&content)?; // Expand environment variables in API keys and URLs config.expand_env_vars(); config.validate()?; Ok(config) } /// Expand environment variables in configuration fn expand_env_vars(&mut self) { for (_, provider) in self.providers.iter_mut() { if let Some(ref api_key) = provider.api_key { provider.api_key = Some(expand_env_var(api_key)); } if let Some(ref url) = provider.url { provider.url = Some(expand_env_var(url)); } } } /// Validate configuration fn validate(&self) -> Result<(), ConfigError> { // Check that default provider exists if !self.providers.contains_key(&self.routing.default_provider) { return Err(ConfigError::ValidationError(format!( "Default provider '{}' not found in providers", self.routing.default_provider ))); } // Check that all routing rules reference valid providers for rule in &self.routing_rules { if !self.providers.contains_key(&rule.provider) { return Err(ConfigError::ValidationError(format!( "Routing rule '{}' references unknown provider '{}'", rule.name, rule.provider ))); } } Ok(()) } /// Get provider configuration by name pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> { self.providers.get(name) } /// Find routing rule matching conditions pub fn find_rule(&self, conditions: &HashMap) -> Option<&RoutingRule> { self.routing_rules.iter().find(|rule| { rule.condition .iter() .all(|(key, value)| conditions.get(key).map(|v| v == value).unwrap_or(false)) }) } } /// Expand environment variables in format ${VAR} or ${VAR:-default} fn expand_env_var(input: &str) -> String { if !input.starts_with("${") || !input.ends_with('}') { return input.to_string(); } let var_part = &input[2..input.len() - 1]; // Handle ${VAR:-default} format if let Some(pos) = var_part.find(":-") { let var_name = &var_part[..pos]; let default_value = &var_part[pos + 2..]; std::env::var(var_name).unwrap_or_else(|_| default_value.to_string()) } else { // Handle ${VAR} format std::env::var(var_part).unwrap_or_default() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_expand_env_var() { std::env::set_var("TEST_VAR", "test_value"); assert_eq!(expand_env_var("${TEST_VAR}"), "test_value"); assert_eq!(expand_env_var("plain_text"), "plain_text"); assert_eq!(expand_env_var("${NONEXISTENT:-default}"), "default"); } #[test] fn test_config_validation() { let config = LLMRouterConfig { routing: RoutingConfig { default_provider: "claude".to_string(), cost_tracking_enabled: true, fallback_enabled: true, }, providers: { let mut map = HashMap::new(); map.insert( "claude".to_string(), ProviderConfig { enabled: true, api_key: Some("test".to_string()), url: None, model: "claude-sonnet-4".to_string(), max_tokens: 4096, temperature: 0.7, cost_per_1m_input: 3.0, cost_per_1m_output: 15.0, }, ); map }, routing_rules: vec![], }; assert!(config.validate().is_ok()); } #[test] fn test_invalid_default_provider() { let config = LLMRouterConfig { routing: RoutingConfig { default_provider: "nonexistent".to_string(), cost_tracking_enabled: true, fallback_enabled: true, }, providers: HashMap::new(), routing_rules: vec![], }; assert!(config.validate().is_err()); } }