165 lines
4.8 KiB
Rust
165 lines
4.8 KiB
Rust
|
|
//! Authentication and authorization for the MCP server.
//!
//! Provides a token-based authentication mechanism for MCP requests:
//!
//! 1. The server's expected token is read from the `KOGRAL_MCP_TOKEN`
//!    environment variable; when the variable is unset, authentication is disabled.
//! 2. Clients present their token in the `token` field of the JSON-RPC request params.
|
||
|
|
|
||
|
|
use anyhow::{anyhow, Result};
|
||
|
|
use serde_json::Value;
|
||
|
|
use tracing::{debug, warn};
|
||
|
|
|
||
|
|
/// Authentication configuration for the MCP server.
///
/// `Debug` is implemented by hand (not derived) so the configured secret never
/// ends up in logs or panic messages; the token is rendered as `"<redacted>"`.
pub struct AuthConfig {
    /// Expected token for authentication (from environment or config)
    token: Option<String>,
    /// Whether requests must present a token. If this is `true` while `token`
    /// is `None`, [`AuthConfig::verify`] rejects every request.
    required: bool,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            // Only reveal whether a token is configured, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("required", &self.required)
            .finish()
    }
}
|
||
|
|
|
||
|
|
impl AuthConfig {
|
||
|
|
/// Create authentication config from environment and settings
|
||
|
|
pub fn from_env() -> Self {
|
||
|
|
let token = std::env::var("KOGRAL_MCP_TOKEN").ok();
|
||
|
|
let required = token.is_some();
|
||
|
|
|
||
|
|
Self { token, required }
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Verify a request token against the configured token
|
||
|
|
///
|
||
|
|
/// # Arguments
|
||
|
|
/// * `request_token` - Token provided in the request (can be from params or header)
|
||
|
|
///
|
||
|
|
/// # Returns
|
||
|
|
/// * Ok(()) if authentication succeeds
|
||
|
|
/// * Err if authentication fails
|
||
|
|
pub fn verify(&self, request_token: Option<&str>) -> Result<()> {
|
||
|
|
if !self.required {
|
||
|
|
// Authentication not required
|
||
|
|
return Ok(());
|
||
|
|
}
|
||
|
|
|
||
|
|
let configured_token = match &self.token {
|
||
|
|
Some(t) => t,
|
||
|
|
None => return Err(anyhow!("Authentication required but no token configured")),
|
||
|
|
};
|
||
|
|
|
||
|
|
let provided_token = match request_token {
|
||
|
|
Some(t) => t,
|
||
|
|
None => {
|
||
|
|
warn!("Authentication required but no token provided");
|
||
|
|
return Err(anyhow!("Missing authentication token"));
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
// Constant-time comparison to prevent timing attacks
|
||
|
|
if constant_time_compare(provided_token, configured_token) {
|
||
|
|
debug!("Authentication successful");
|
||
|
|
Ok(())
|
||
|
|
} else {
|
||
|
|
warn!("Authentication failed: invalid token");
|
||
|
|
Err(anyhow!("Invalid authentication token"))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Whether authentication is required for this server
|
||
|
|
pub fn is_required(&self) -> bool {
|
||
|
|
self.required
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Extract token from JSON-RPC request parameters
|
||
|
|
///
|
||
|
|
/// Supports multiple formats:
|
||
|
|
/// 1. `{"token": "..."}` - Direct token field
|
||
|
|
/// 2. `{"params": {"token": "..."}...}` - Token in params object
|
||
|
|
pub fn extract_token_from_params(params: &Value) -> Option<&str> {
|
||
|
|
// Try direct token field
|
||
|
|
if let Some(token) = params.get("token").and_then(|v| v.as_str()) {
|
||
|
|
return Some(token);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Try nested in object
|
||
|
|
if let Some(obj) = params.as_object() {
|
||
|
|
if let Some(token_value) = obj.get("token") {
|
||
|
|
return token_value.as_str();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
None
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Best-effort constant-time string comparison to prevent timing attacks.
///
/// Returns `true` iff `a` and `b` are byte-for-byte equal. A length mismatch
/// is folded into the accumulator instead of returning early, so an early exit
/// cannot reveal the length of `b`. The loop always walks all of `a`, which
/// callers should use for the request-supplied value. Note that the optimizer
/// gives no hard constant-time guarantee.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());

    // Non-zero iff the lengths differ.
    let mut diff = a.len() ^ b.len();

    // `cycle` keeps the per-byte work tied to `a.len()` even when `b` is
    // shorter; when the lengths are equal it is a plain positional compare.
    // An empty `b` yields no pairs, and `diff` already holds `a.len()`.
    for (&x, &y) in a.iter().zip(b.iter().cycle()) {
        diff |= usize::from(x ^ y);
    }

    diff == 0
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that mutate `KOGRAL_MCP_TOKEN`: the test harness runs
    /// tests on parallel threads and the process environment is shared.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    fn lock_env() -> MutexGuard<'static, ()> {
        // A panic in another env test poisons the lock; the guarded data is `()`,
        // so it is safe to continue.
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_auth_config_no_env_token() {
        let _env = lock_env();
        // When no token is set in environment
        std::env::remove_var("KOGRAL_MCP_TOKEN");
        let config = AuthConfig::from_env();
        assert!(!config.is_required());
        assert!(config.verify(None).is_ok());
        assert!(config.verify(Some("any-token")).is_ok());
    }

    #[test]
    fn test_auth_config_with_env_token() {
        let _env = lock_env();
        // When token is set in environment
        std::env::set_var("KOGRAL_MCP_TOKEN", "secret-token");
        let config = AuthConfig::from_env();
        // Restore the environment immediately so later tests start clean.
        std::env::remove_var("KOGRAL_MCP_TOKEN");
        assert!(config.is_required());

        // Missing token should fail
        assert!(config.verify(None).is_err());

        // Wrong token should fail
        assert!(config.verify(Some("wrong-token")).is_err());

        // Prefix of the real token should fail
        assert!(config.verify(Some("secret")).is_err());

        // Correct token should succeed
        assert!(config.verify(Some("secret-token")).is_ok());
    }

    #[test]
    fn test_extract_token_from_direct_field() {
        let params = serde_json::json!({"token": "my-token", "other": "field"});
        assert_eq!(extract_token_from_params(&params), Some("my-token"));
    }

    #[test]
    fn test_extract_token_missing() {
        let params = serde_json::json!({"other": "field"});
        assert_eq!(extract_token_from_params(&params), None);
    }

    #[test]
    fn test_extract_token_non_string() {
        let params = serde_json::json!({"token": 42});
        assert_eq!(extract_token_from_params(&params), None);
    }

    #[test]
    fn test_extract_token_non_object() {
        assert_eq!(extract_token_from_params(&serde_json::json!([1, 2])), None);
        assert_eq!(extract_token_from_params(&serde_json::Value::Null), None);
    }

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("secret", "secret"));
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn test_constant_time_compare_different() {
        assert!(!constant_time_compare("secret", "wrong"));
        assert!(!constant_time_compare("secret", "secreT"));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        assert!(!constant_time_compare("short", "much-longer-string"));
        assert!(!constant_time_compare("secret", "secret-suffix"));
        assert!(!constant_time_compare("secret-suffix", "secret"));
    }
}
|