//! Authentication and authorization for MCP server //! //! Provides token-based authentication mechanism for MCP requests. //! Tokens can be provided via: //! 1. Environment variable: `KOGRAL_MCP_TOKEN` //! 2. JSON-RPC request parameter: `token` field in params use anyhow::{anyhow, Result}; use serde_json::Value; use tracing::{debug, warn}; /// Authentication configuration pub struct AuthConfig { /// Expected token for authentication (from environment or config) token: Option, /// Whether authentication is required (true if token is set) required: bool, } impl AuthConfig { /// Create authentication config from environment and settings pub fn from_env() -> Self { let token = std::env::var("KOGRAL_MCP_TOKEN").ok(); let required = token.is_some(); Self { token, required } } /// Verify a request token against the configured token /// /// # Arguments /// * `request_token` - Token provided in the request (can be from params or header) /// /// # Returns /// * Ok(()) if authentication succeeds /// * Err if authentication fails pub fn verify(&self, request_token: Option<&str>) -> Result<()> { if !self.required { // Authentication not required return Ok(()); } let configured_token = match &self.token { Some(t) => t, None => return Err(anyhow!("Authentication required but no token configured")), }; let provided_token = match request_token { Some(t) => t, None => { warn!("Authentication required but no token provided"); return Err(anyhow!("Missing authentication token")); } }; // Constant-time comparison to prevent timing attacks if constant_time_compare(provided_token, configured_token) { debug!("Authentication successful"); Ok(()) } else { warn!("Authentication failed: invalid token"); Err(anyhow!("Invalid authentication token")) } } /// Whether authentication is required for this server pub fn is_required(&self) -> bool { self.required } } /// Extract token from JSON-RPC request parameters /// /// Supports multiple formats: /// 1. `{"token": "..."}` - Direct token field /// 2. `{"params": {"token": "..."}...}` - Token in params object pub fn extract_token_from_params(params: &Value) -> Option<&str> { // Try direct token field if let Some(token) = params.get("token").and_then(|v| v.as_str()) { return Some(token); } // Try nested in object if let Some(obj) = params.as_object() { if let Some(token_value) = obj.get("token") { return token_value.as_str(); } } None } /// Constant-time string comparison to prevent timing attacks fn constant_time_compare(a: &str, b: &str) -> bool { if a.len() != b.len() { return false; } let mut result = 0u8; for (a_byte, b_byte) in a.bytes().zip(b.bytes()) { result |= a_byte ^ b_byte; } result == 0 } #[cfg(test)] mod tests { use super::*; #[test] fn test_auth_config_no_env_token() { // When no token is set in environment std::env::remove_var("KOGRAL_MCP_TOKEN"); let config = AuthConfig::from_env(); assert!(!config.is_required()); assert!(config.verify(None).is_ok()); assert!(config.verify(Some("any-token")).is_ok()); } #[test] fn test_auth_config_with_env_token() { // When token is set in environment std::env::set_var("KOGRAL_MCP_TOKEN", "secret-token"); let config = AuthConfig::from_env(); assert!(config.is_required()); // Missing token should fail assert!(config.verify(None).is_err()); // Wrong token should fail assert!(config.verify(Some("wrong-token")).is_err()); // Correct token should succeed assert!(config.verify(Some("secret-token")).is_ok()); } #[test] fn test_extract_token_from_direct_field() { let params = serde_json::json!({"token": "my-token", "other": "field"}); assert_eq!(extract_token_from_params(¶ms), Some("my-token")); } #[test] fn test_extract_token_missing() { let params = serde_json::json!({"other": "field"}); assert_eq!(extract_token_from_params(¶ms), None); } #[test] fn test_constant_time_compare_equal() { assert!(constant_time_compare("secret", "secret")); } #[test] fn test_constant_time_compare_different() { assert!(!constant_time_compare("secret", "wrong")); } #[test] fn test_constant_time_compare_different_length() { assert!(!constant_time_compare("short", "much-longer-string")); } }