use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use base64::engine::general_purpose::STANDARD as BASE64; use base64::Engine as _; use serde_json::{json, Value}; use super::Engine; use crate::core::SealMechanism; use crate::crypto::{CryptoBackend, KeyAlgorithm, SymmetricAlgorithm}; use crate::error::{Result, VaultError}; use crate::storage::StorageBackend; /// Encrypted key version #[derive(Debug, Clone)] struct TransitKey { name: String, versions: HashMap, current_version: u64, min_decrypt_version: u64, } /// Individual key version #[derive(Debug, Clone)] struct KeyVersion { /// Key algorithm (AES-256-GCM for symmetric, ML-KEM-768 for PQC) algorithm: KeyAlgorithm, /// For symmetric: AES key material (32 bytes) /// For ML-KEM-768: serialized keypair (public + private) key_material: Vec, #[allow(dead_code)] created_at: chrono::DateTime, } /// Transit key algorithm types #[derive(Debug, Clone, Copy, PartialEq)] pub enum TransitKeyAlgorithm { /// AES-256-GCM symmetric encryption (legacy) Aes256Gcm, /// ML-KEM-768 post-quantum key wrapping #[cfg(feature = "pqc")] MlKem768, } /// Transit secrets engine for encryption/decryption pub struct TransitEngine { storage: Arc, crypto: Arc, seal: Arc>, #[allow(dead_code)] mount_path: String, keys: Arc>>, } impl TransitEngine { /// Create a new Transit engine instance pub fn new( storage: Arc, crypto: Arc, seal: Arc>, mount_path: String, ) -> Self { Self { storage, crypto, seal, mount_path, keys: Arc::new(tokio::sync::Mutex::new(HashMap::new())), } } /// Get storage key for transit key #[allow(dead_code)] fn storage_key(&self, key_name: &str) -> String { format!("{}keys/{}", self.mount_path, key_name) } /// Create or update a transit key (symmetric AES-256-GCM) pub async fn create_key(&self, key_name: &str, key_material: Vec) -> Result<()> { self.create_key_with_algorithm(key_name, key_material, TransitKeyAlgorithm::Aes256Gcm) .await } /// Create or update a transit key with specific algorithm pub async fn create_key_with_algorithm( &self, key_name: &str, key_material: Vec, algorithm: TransitKeyAlgorithm, ) -> Result<()> { let now = chrono::Utc::now(); let mut keys = self.keys.lock().await; let key_algorithm = match algorithm { TransitKeyAlgorithm::Aes256Gcm => KeyAlgorithm::Rsa2048, // Placeholder #[cfg(feature = "pqc")] TransitKeyAlgorithm::MlKem768 => KeyAlgorithm::MlKem768, }; if let Some(key) = keys.get_mut(key_name) { // Existing key - increment version let next_version = key.current_version + 1; key.versions.insert( next_version, KeyVersion { algorithm: key_algorithm, key_material, created_at: now, }, ); key.current_version = next_version; } else { // New key - create with version 1 let mut key = TransitKey { name: key_name.to_string(), versions: HashMap::new(), current_version: 1, min_decrypt_version: 1, }; key.versions.insert( 1, KeyVersion { algorithm: key_algorithm, key_material, created_at: now, }, ); keys.insert(key_name.to_string(), key); } Ok(()) } /// Create ML-KEM-768 transit key for post-quantum encryption #[cfg(feature = "pqc")] pub async fn create_pqc_key(&self, key_name: &str) -> Result<()> { // Generate ML-KEM-768 keypair let keypair = self .crypto .generate_keypair(KeyAlgorithm::MlKem768) .await .map_err(|e| VaultError::crypto(e.to_string()))?; // Serialize keypair (public + private concatenated) let mut key_material = Vec::new(); key_material.extend_from_slice(&keypair.public_key.key_data); key_material.extend_from_slice(&keypair.private_key.key_data); self.create_key_with_algorithm(key_name, key_material, TransitKeyAlgorithm::MlKem768) .await } /// Encrypt plaintext using the specified key pub async fn encrypt(&self, key_name: &str, plaintext: &[u8]) -> Result { let keys = self.keys.lock().await; let key = keys .get(key_name) .ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?; let key_version = key .versions .get(&key.current_version) .ok_or_else(|| VaultError::crypto("Key version not found".to_string()))?; let key_material = key_version.key_material.clone(); #[cfg(feature = "pqc")] let key_algorithm = key_version.algorithm; let current_version = key.current_version; drop(keys); #[cfg(feature = "pqc")] if key_algorithm == KeyAlgorithm::MlKem768 { // ML-KEM-768 key wrapping // Parse keypair from serialized format if key_material.len() < 1184 { return Err(VaultError::crypto( "Invalid ML-KEM-768 key material".to_string(), )); } let public_key_data = &key_material[..1184]; let public_key = crate::crypto::PublicKey { algorithm: KeyAlgorithm::MlKem768, key_data: public_key_data.to_vec(), }; // KEM encapsulation to get shared secret let (kem_ct, shared_secret) = self .crypto .kem_encapsulate(&public_key) .await .map_err(|e| VaultError::crypto(format!("KEM encapsulation failed: {}", e)))?; // Encrypt plaintext with shared secret as AES key let aes_ct = self .crypto .encrypt_symmetric(&shared_secret, plaintext, SymmetricAlgorithm::Aes256Gcm) .await .map_err(|e| VaultError::crypto(e.to_string()))?; // Wire format: [kem_ct_len:4][kem_ct][aes_ct] let mut combined = Vec::with_capacity(4 + kem_ct.len() + aes_ct.len()); combined.extend_from_slice(&(kem_ct.len() as u32).to_be_bytes()); combined.extend_from_slice(&kem_ct); combined.extend_from_slice(&aes_ct); // Format: vault:v{version}:base64_encoded_ciphertext let encoded = BASE64.encode(&combined); return Ok(format!("vault:v{}:{}", current_version, encoded)); } // AES-256-GCM symmetric encryption (legacy path) let ciphertext = self .crypto .encrypt_symmetric(&key_material, plaintext, SymmetricAlgorithm::Aes256Gcm) .await .map_err(|e| VaultError::crypto(e.to_string()))?; // Format: vault:v{version}:base64_encoded_ciphertext let encoded = BASE64.encode(&ciphertext); Ok(format!("vault:v{}:{}", current_version, encoded)) } /// Decrypt ciphertext using the appropriate key version pub async fn decrypt(&self, key_name: &str, ciphertext_str: &str) -> Result> { // Parse vault format: vault:v{version}:base64_data let parts: Vec<&str> = ciphertext_str.split(':').collect(); if parts.len() != 3 || parts[0] != "vault" { return Err(VaultError::crypto( "Invalid vault ciphertext format".to_string(), )); } let version_str = parts[1] .strip_prefix('v') .ok_or_else(|| VaultError::crypto("Invalid version format".to_string()))?; let version: u64 = version_str .parse() .map_err(|e| VaultError::crypto(format!("Failed to parse version: {}", e)))?; let ciphertext = BASE64 .decode(parts[2]) .map_err(|e| VaultError::crypto(format!("Failed to decode ciphertext: {}", e)))?; let keys = self.keys.lock().await; let key = keys .get(key_name) .ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?; if version < key.min_decrypt_version { return Err(VaultError::crypto(format!( "Key version {} is below minimum decrypt version {}", version, key.min_decrypt_version ))); } let key_version = key .versions .get(&version) .ok_or_else(|| VaultError::crypto(format!("Key version {} not found", version)))?; let key_material = key_version.key_material.clone(); let _key_algorithm = key_version.algorithm; drop(keys); #[cfg(feature = "pqc")] if _key_algorithm == KeyAlgorithm::MlKem768 { // ML-KEM-768 key unwrapping // Parse wire format: [kem_ct_len:4][kem_ct][aes_ct] if ciphertext.len() < 4 { return Err(VaultError::crypto( "Invalid KEM ciphertext format".to_string(), )); } let kem_ct_len = u32::from_be_bytes([ciphertext[0], ciphertext[1], ciphertext[2], ciphertext[3]]) as usize; if ciphertext.len() < 4 + kem_ct_len { return Err(VaultError::crypto("Truncated KEM ciphertext".to_string())); } let kem_ct = &ciphertext[4..4 + kem_ct_len]; let aes_ct = &ciphertext[4 + kem_ct_len..]; // Parse keypair from serialized format if key_material.len() < 1184 + 2400 { return Err(VaultError::crypto( "Invalid ML-KEM-768 key material".to_string(), )); } let private_key_data = &key_material[1184..1184 + 2400]; let private_key = crate::crypto::PrivateKey { algorithm: KeyAlgorithm::MlKem768, key_data: private_key_data.to_vec(), }; // KEM decapsulation to get shared secret let shared_secret = self .crypto .kem_decapsulate(&private_key, kem_ct) .await .map_err(|e| VaultError::crypto(format!("KEM decapsulation failed: {}", e)))?; // Decrypt AES ciphertext with shared secret let plaintext = self .crypto .decrypt_symmetric(&shared_secret, aes_ct, SymmetricAlgorithm::Aes256Gcm) .await .map_err(|e| VaultError::crypto(e.to_string()))?; return Ok(plaintext); } // AES-256-GCM symmetric decryption (legacy path) self.crypto .decrypt_symmetric(&key_material, &ciphertext, SymmetricAlgorithm::Aes256Gcm) .await .map_err(|e| VaultError::crypto(e.to_string())) } /// Rewrap ciphertext under the current key version pub async fn rewrap(&self, key_name: &str, ciphertext_str: &str) -> Result { let plaintext = self.decrypt(key_name, ciphertext_str).await?; self.encrypt(key_name, &plaintext).await } } #[async_trait] impl Engine for TransitEngine { fn name(&self) -> &str { "transit" } fn engine_type(&self) -> &str { "transit" } async fn read(&self, path: &str) -> Result> { if let Some(key_name) = path.strip_prefix("keys/") { let keys = self.keys.lock().await; if let Some(key) = keys.get(key_name) { let key_version = key.versions.get(&key.current_version); let mut response = json!({ "name": key.name, "current_version": key.current_version, "min_decrypt_version": key.min_decrypt_version, }); // Add public key and creation timestamp for PQC keys if let Some(kv) = key_version { response["algorithm"] = json!(kv.algorithm.as_str()); response["created_at"] = json!(kv.created_at.to_rfc3339()); // For ML-KEM-768, extract and base64 encode the public key #[cfg(feature = "pqc")] if kv.algorithm == KeyAlgorithm::MlKem768 && kv.key_material.len() >= 1184 { let public_key_data = &kv.key_material[..1184]; response["public_key"] = json!(BASE64.encode(public_key_data)); } } return Ok(Some(response)); } } Ok(None) } async fn write(&self, path: &str, data: &Value) -> Result<()> { if let Some(key_name) = path.strip_prefix("encrypt/") { let plaintext = data .get("plaintext") .and_then(|v| v.as_str()) .ok_or_else(|| VaultError::storage("Missing 'plaintext' in request".to_string()))?; let _ciphertext = self.encrypt(key_name, plaintext.as_bytes()).await?; // Note: In a full implementation, this would return the ciphertext // in the response } else if let Some(key_name) = path.strip_prefix("decrypt/") { let ciphertext = data .get("ciphertext") .and_then(|v| v.as_str()) .ok_or_else(|| { VaultError::storage("Missing 'ciphertext' in request".to_string()) })?; let _plaintext = self.decrypt(key_name, ciphertext).await?; // Note: In a full implementation, this would return the plaintext // in the response } else if let Some(key_name) = path.strip_prefix("rewrap/") { let ciphertext = data .get("ciphertext") .and_then(|v| v.as_str()) .ok_or_else(|| { VaultError::storage("Missing 'ciphertext' in request".to_string()) })?; let _new_ciphertext = self.rewrap(key_name, ciphertext).await?; // Note: In a full implementation, this would return the new // ciphertext in the response } else if let Some(rest) = path.strip_prefix("pqc-keys/") { if rest.ends_with("/generate") { let _key_name = rest.trim_end_matches("/generate"); #[cfg(feature = "pqc")] self.create_pqc_key(_key_name).await?; } } Ok(()) } async fn delete(&self, path: &str) -> Result<()> { if let Some(key_name) = path.strip_prefix("keys/") { let mut keys = self.keys.lock().await; keys.remove(key_name); } Ok(()) } async fn list(&self, prefix: &str) -> Result> { let keys = self.keys.lock().await; let mut result = Vec::new(); for key_name in keys.keys() { if key_name.starts_with(prefix) { result.push(key_name.clone()); } } Ok(result) } async fn health_check(&self) -> Result<()> { self.storage .health_check() .await .map_err(|e| VaultError::storage(e.to_string()))?; let seal = self.seal.lock().await; if seal.is_sealed() { return Err(VaultError::crypto("Vault is sealed".to_string())); } Ok(()) } } #[cfg(test)] mod tests { use tempfile::TempDir; use super::*; use crate::config::{FilesystemStorageConfig, SealConfig, ShamirSealConfig, StorageConfig}; use crate::crypto::CryptoRegistry; use crate::storage::StorageRegistry; async fn setup_engine() -> Result<(TransitEngine, TempDir)> { let temp_dir = TempDir::new().map_err(|e| VaultError::storage(e.to_string()))?; let fs_config = FilesystemStorageConfig { path: temp_dir.path().to_path_buf(), }; let storage_config = StorageConfig { backend: "filesystem".to_string(), filesystem: fs_config, surrealdb: Default::default(), etcd: Default::default(), postgresql: Default::default(), }; let storage = StorageRegistry::create(&storage_config).await?; let crypto = CryptoRegistry::create("openssl", &Default::default())?; let seal_config = SealConfig { seal_type: "shamir".to_string(), shamir: ShamirSealConfig { threshold: 2, shares: 3, }, auto_unseal: Default::default(), }; let mut seal = crate::core::SealMechanism::new(&seal_config)?; // Initialize and unseal for testing let _init_result = seal.init(crypto.as_ref(), storage.as_ref()).await?; let seal_arc = Arc::new(tokio::sync::Mutex::new(seal)); let engine = TransitEngine::new(storage, crypto.clone(), seal_arc, "transit/".to_string()); Ok((engine, temp_dir)) } #[allow(dead_code)] fn mock_key_name() -> String { "my-key".to_string() } #[tokio::test] async fn test_transit_encrypt_decrypt() -> Result<()> { let (engine, _temp) = setup_engine().await?; let plaintext = b"sensitive data"; engine.create_key("my-key", vec![0x42; 32]).await?; let ciphertext = engine.encrypt("my-key", plaintext).await?; assert!(ciphertext.starts_with("vault:v")); let decrypted = engine.decrypt("my-key", &ciphertext).await?; assert_eq!(decrypted, plaintext); Ok(()) } #[tokio::test] async fn test_transit_key_rotation() -> Result<()> { let (engine, _temp) = setup_engine().await?; engine.create_key("my-key", vec![0x11; 32]).await?; let ct1 = engine.encrypt("my-key", b"data v1").await?; // Rotate key engine.create_key("my-key", vec![0x22; 32]).await?; let ct2 = engine.encrypt("my-key", b"data v2").await?; // Should use different versions assert!(ct1.contains(":v1:")); assert!(ct2.contains(":v2:")); Ok(()) } #[tokio::test] async fn test_transit_rewrap() -> Result<()> { let (engine, _temp) = setup_engine().await?; engine.create_key("my-key", vec![0x42; 32]).await?; let ct1 = engine.encrypt("my-key", b"test data").await?; // Rotate and rewrap engine.create_key("my-key", vec![0x99; 32]).await?; let ct2 = engine.rewrap("my-key", &ct1).await?; // Rewrapped should use new version assert!(ct2.contains(":v2:")); Ok(()) } #[tokio::test] async fn test_transit_invalid_ciphertext() -> Result<()> { let (engine, _temp) = setup_engine().await?; engine.create_key("my-key", vec![0x42; 32]).await?; let result = engine.decrypt("my-key", "invalid:format").await; assert!(result.is_err()); Ok(()) } #[tokio::test] async fn test_transit_health_check() -> Result<()> { let (engine, _temp) = setup_engine().await?; engine.health_check().await?; Ok(()) } }