165 lines
4.8 KiB
Rust
Raw Normal View History

2026-01-23 16:13:23 +00:00
//! Authentication and authorization for MCP server
//!
//! Provides a token-based authentication mechanism for MCP requests:
//! 1. The server's expected token is read from the `KOGRAL_MCP_TOKEN`
//!    environment variable; when the variable is unset, authentication is
//!    disabled.
//! 2. Clients send their token as the `token` field of the JSON-RPC request
//!    `params`.
use anyhow::{anyhow, Result};
use serde_json::Value;
use tracing::{debug, warn};
/// Authentication configuration.
///
/// Holds the server's expected shared-secret token. Build it with
/// [`AuthConfig::from_env`]. This type intentionally does not derive `Debug`,
/// so the secret cannot leak into logs through `{:?}` formatting.
pub struct AuthConfig {
    /// Expected token for authentication.
    ///
    /// `None` together with `required == true` means the environment variable
    /// was present but unusable (empty or not valid UTF-8). In that state every
    /// request is rejected (fail closed).
    token: Option<String>,
    /// Whether authentication is required: true whenever `KOGRAL_MCP_TOKEN`
    /// is present in the environment, even if its value is unusable.
    required: bool,
}

impl AuthConfig {
    /// Create authentication config from the process environment.
    ///
    /// * `KOGRAL_MCP_TOKEN` unset: authentication is disabled.
    /// * Set to a non-empty UTF-8 value: that value is the required token.
    /// * Set but empty or not valid UTF-8: authentication is required but no
    ///   token can match, so every request is rejected. A misconfigured secret
    ///   must never silently turn authentication off, and an empty secret is
    ///   no secret at all.
    pub fn from_env() -> Self {
        match std::env::var_os("KOGRAL_MCP_TOKEN") {
            None => Self {
                token: None,
                required: false,
            },
            Some(raw) => {
                // `var(..).ok()` would map a non-UTF-8 value to "unset" and
                // disable auth; inspect the raw value instead and fail closed.
                let token = raw.into_string().ok().filter(|t| !t.is_empty());
                if token.is_none() {
                    warn!(
                        "KOGRAL_MCP_TOKEN is set but empty or not valid UTF-8; \
                         all requests will be rejected"
                    );
                }
                Self {
                    token,
                    required: true,
                }
            }
        }
    }

    /// Verify a request token against the configured token.
    ///
    /// # Arguments
    /// * `request_token` - Token provided in the request (can be from params or header)
    ///
    /// # Returns
    /// * `Ok(())` if authentication is not required, or if the token matches.
    ///
    /// # Errors
    /// When authentication is required, returns an error if:
    /// * no usable token is configured (the server rejects everything),
    /// * the request carries no token, or
    /// * the request token does not match the configured one.
    pub fn verify(&self, request_token: Option<&str>) -> Result<()> {
        if !self.required {
            // Authentication not required
            return Ok(());
        }

        // Reachable when the env var was present but unusable: fail closed.
        let configured_token = match self.token.as_deref() {
            Some(t) => t,
            None => return Err(anyhow!("Authentication required but no token configured")),
        };

        let provided_token = match request_token {
            Some(t) => t,
            None => {
                warn!("Authentication required but no token provided");
                return Err(anyhow!("Missing authentication token"));
            }
        };

        // Constant-time comparison to prevent timing attacks
        if constant_time_compare(provided_token, configured_token) {
            debug!("Authentication successful");
            Ok(())
        } else {
            warn!("Authentication failed: invalid token");
            Err(anyhow!("Invalid authentication token"))
        }
    }

    /// Whether authentication is required for this server.
    pub fn is_required(&self) -> bool {
        self.required
    }
}
/// Extract the authentication token from a JSON-RPC value.
///
/// Supports two shapes:
/// 1. `{"token": "..."}` - token directly in the params object.
/// 2. `{"params": {"token": "..."}, ...}` - a full request whose `params`
///    object carries the token.
///
/// A top-level `token` key takes precedence. If that key is present but is
/// not a string, `None` is returned without consulting the nested form.
///
/// # Returns
/// The token string, or `None` if no string token is found in either shape
/// (including when `params` is not a JSON object).
pub fn extract_token_from_params(params: &Value) -> Option<&str> {
    match params.get("token") {
        // Direct field: its string value, or None if it has another type.
        Some(token) => token.as_str(),
        // Nested form: `Value::get` yields None for non-objects, so `?`
        // short-circuits cleanly at each level.
        None => params.get("params")?.get("token")?.as_str(),
    }
}
/// Compare two strings without short-circuiting, to resist timing attacks.
///
/// The loop always runs once per byte of `b`; callers pass the configured
/// secret there. A length mismatch is folded into the result rather than
/// causing an early return. As a result, timing reveals neither whether `a`
/// has the secret's length nor how many leading bytes of `a` are correct.
///
/// This is best-effort: Rust gives no formal guarantee that the optimizer
/// will not introduce data-dependent branches.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // XOR of the lengths is non-zero exactly when the lengths differ.
    let mut diff = a.len() ^ b.len();
    for (i, &b_byte) in b.iter().enumerate() {
        // Bytes missing from a shorter `a` compare against 0. The length
        // term above already guarantees a mismatch in that case.
        let a_byte = a.get(i).copied().unwrap_or(0);
        diff |= usize::from(a_byte ^ b_byte);
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// The test harness runs tests on parallel threads, and the process
    /// environment is shared between them. Every test that reads or writes
    /// `KOGRAL_MCP_TOKEN` must hold this lock, or the tests race each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, ignoring poisoning left by a previously failed
    /// test so that one failure does not cascade into the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_auth_config_no_env_token() {
        let _guard = lock_env();
        // When no token is set in environment
        std::env::remove_var("KOGRAL_MCP_TOKEN");
        let config = AuthConfig::from_env();
        assert!(!config.is_required());
        assert!(config.verify(None).is_ok());
        assert!(config.verify(Some("any-token")).is_ok());
    }

    #[test]
    fn test_auth_config_with_env_token() {
        let _guard = lock_env();
        // When token is set in environment
        std::env::set_var("KOGRAL_MCP_TOKEN", "secret-token");
        let config = AuthConfig::from_env();
        // `config` has captured the value; restore the environment before any
        // assertion can panic, so it is always cleaned up.
        std::env::remove_var("KOGRAL_MCP_TOKEN");
        assert!(config.is_required());
        // Missing token should fail
        assert!(config.verify(None).is_err());
        // Wrong token should fail
        assert!(config.verify(Some("wrong-token")).is_err());
        // A prefix of the correct token should fail
        assert!(config.verify(Some("secret")).is_err());
        // Correct token should succeed
        assert!(config.verify(Some("secret-token")).is_ok());
    }

    #[test]
    fn test_extract_token_from_direct_field() {
        let params = serde_json::json!({"token": "my-token", "other": "field"});
        assert_eq!(extract_token_from_params(&params), Some("my-token"));
    }

    #[test]
    fn test_extract_token_missing() {
        let params = serde_json::json!({"other": "field"});
        assert_eq!(extract_token_from_params(&params), None);
    }

    #[test]
    fn test_extract_token_non_string_or_non_object() {
        // A non-string token is not accepted.
        let params = serde_json::json!({"token": 42});
        assert_eq!(extract_token_from_params(&params), None);
        // Non-object params carry no token.
        assert_eq!(extract_token_from_params(&serde_json::json!("token")), None);
        assert_eq!(extract_token_from_params(&serde_json::json!([1, 2])), None);
    }

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("secret", "secret"));
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn test_constant_time_compare_different() {
        assert!(!constant_time_compare("secret", "wrong"));
        assert!(!constant_time_compare("secreT", "secret"));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        assert!(!constant_time_compare("short", "much-longer-string"));
        assert!(!constant_time_compare("much-longer-string", "short"));
        // One string being a prefix of the other must not match.
        assert!(!constant_time_compare("secret", "secret-suffix"));
        assert!(!constant_time_compare("secret-suffix", "secret"));
    }
}