Jesús Pérez 18d9d41c34
Some checks failed
Rust CI / Security Audit (push) Has been cancelled
Rust CI / Check + Test + Lint (nightly) (push) Has been cancelled
Rust CI / Check + Test + Lint (stable) (push) Has been cancelled
chore: update crates
2026-02-03 22:04:51 +00:00

575 lines
19 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use serde_json::{json, Value};
use super::Engine;
use crate::core::SealMechanism;
use crate::crypto::{CryptoBackend, KeyAlgorithm, SymmetricAlgorithm};
use crate::error::{Result, VaultError};
use crate::storage::StorageBackend;
/// A named transit key together with every version created for it.
///
/// New encryptions always use `current_version`; decryption accepts any
/// stored version that is not older than `min_decrypt_version`.
#[derive(Debug, Clone)]
struct TransitKey {
    /// Name the key was registered under.
    name: String,
    /// All versions created so far, keyed by version number (the first is 1).
    versions: HashMap<u64, KeyVersion>,
    /// Version used for new encryptions; incremented on every rotation.
    current_version: u64,
    /// Oldest version that decryption will still accept.
    min_decrypt_version: u64,
}
/// A single version of a transit key.
#[derive(Debug, Clone)]
struct KeyVersion {
    /// Algorithm tag for this version.
    ///
    /// ML-KEM-768 versions store `KeyAlgorithm::MlKem768`. Symmetric
    /// AES-256-GCM versions currently store the placeholder
    /// `KeyAlgorithm::Rsa2048` (see `TransitEngine::create_key_with_algorithm`),
    /// so only the ML-KEM tag is meaningful when dispatching encrypt/decrypt.
    algorithm: KeyAlgorithm,
    /// Raw key bytes.
    ///
    /// Symmetric: the AES key handed to the crypto backend (expected to be
    /// 32 bytes for AES-256). ML-KEM-768: the serialized public key followed
    /// by the serialized private key.
    key_material: Vec<u8>,
    /// Creation time of this version; reported by `Engine::read`, so it is
    /// not dead code.
    created_at: chrono::DateTime<chrono::Utc>,
}
/// Algorithm used by a transit key version.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so values can be
/// used as keys in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransitKeyAlgorithm {
    /// AES-256-GCM symmetric encryption (legacy).
    Aes256Gcm,
    /// ML-KEM-768 post-quantum key wrapping; only available with the `pqc`
    /// feature.
    #[cfg(feature = "pqc")]
    MlKem768,
}
/// Transit secrets engine: encrypts, decrypts and rewraps data under named,
/// versioned keys.
///
/// NOTE(review): in the code visible here, keys live only in the in-memory
/// `keys` map and `storage` is used solely by `health_check`. Confirm whether
/// keys are meant to be persisted.
pub struct TransitEngine {
    /// Storage backend (currently only probed by `health_check`).
    storage: Arc<dyn StorageBackend>,
    /// Backend that performs every symmetric and KEM operation.
    crypto: Arc<dyn CryptoBackend>,
    /// Seal state; `health_check` fails while the vault is sealed.
    seal: Arc<tokio::sync::Mutex<SealMechanism>>,
    /// Mount prefix for storage paths. It is only read by the unused
    /// `storage_key` helper, hence the lint allowance.
    #[allow(dead_code)]
    mount_path: String,
    /// Transit keys indexed by name, behind an async mutex.
    keys: Arc<tokio::sync::Mutex<HashMap<String, TransitKey>>>,
}
impl TransitEngine {
    /// Size in bytes of a serialized ML-KEM-768 encapsulation (public) key
    /// (FIPS 203).
    #[cfg(feature = "pqc")]
    const ML_KEM_768_PUBLIC_KEY_LEN: usize = 1184;

    /// Size in bytes of a serialized ML-KEM-768 decapsulation (private) key
    /// (FIPS 203).
    #[cfg(feature = "pqc")]
    const ML_KEM_768_PRIVATE_KEY_LEN: usize = 2400;

    /// Size of the big-endian `u32` length prefix in the ML-KEM payload
    /// `[kem_ct_len:4][kem_ct][aes_ct]`.
    #[cfg(feature = "pqc")]
    const KEM_CT_LEN_PREFIX: usize = 4;

    /// Create a new Transit engine instance with an empty key set.
    pub fn new(
        storage: Arc<dyn StorageBackend>,
        crypto: Arc<dyn CryptoBackend>,
        seal: Arc<tokio::sync::Mutex<SealMechanism>>,
        mount_path: String,
    ) -> Self {
        Self {
            storage,
            crypto,
            seal,
            mount_path,
            keys: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Storage path for a transit key: `{mount_path}keys/{key_name}`.
    ///
    /// Not called anywhere in this file yet, hence the lint allowance.
    #[allow(dead_code)]
    fn storage_key(&self, key_name: &str) -> String {
        format!("{}keys/{}", self.mount_path, key_name)
    }

    /// Create a symmetric AES-256-GCM key, or rotate it if it already exists.
    ///
    /// See [`TransitEngine::create_key_with_algorithm`] for versioning rules.
    pub async fn create_key(&self, key_name: &str, key_material: Vec<u8>) -> Result<()> {
        self.create_key_with_algorithm(key_name, key_material, TransitKeyAlgorithm::Aes256Gcm)
            .await
    }

    /// Create a key with the given algorithm, or rotate it if it exists.
    ///
    /// A new name starts at version 1 with `min_decrypt_version` 1. An
    /// existing name gets `current_version + 1`, and that version becomes
    /// current. Older versions stay available for decryption.
    pub async fn create_key_with_algorithm(
        &self,
        key_name: &str,
        key_material: Vec<u8>,
        algorithm: TransitKeyAlgorithm,
    ) -> Result<()> {
        let key_algorithm = match algorithm {
            // NOTE(review): placeholder tag. Encrypt/decrypt only check for
            // `MlKem768`, so symmetric versions correctly take the AES path,
            // but `Engine::read` reports this tag's name as the algorithm.
            // Switch to a symmetric `KeyAlgorithm` variant if one exists.
            TransitKeyAlgorithm::Aes256Gcm => KeyAlgorithm::Rsa2048,
            #[cfg(feature = "pqc")]
            TransitKeyAlgorithm::MlKem768 => KeyAlgorithm::MlKem768,
        };
        let new_version = KeyVersion {
            algorithm: key_algorithm,
            key_material,
            created_at: chrono::Utc::now(),
        };

        let mut keys = self.keys.lock().await;
        if let Some(key) = keys.get_mut(key_name) {
            // Rotation: append the next version and make it current.
            let next_version = key.current_version + 1;
            key.versions.insert(next_version, new_version);
            key.current_version = next_version;
        } else {
            let mut versions = HashMap::new();
            versions.insert(1, new_version);
            keys.insert(
                key_name.to_string(),
                TransitKey {
                    name: key_name.to_string(),
                    versions,
                    current_version: 1,
                    min_decrypt_version: 1,
                },
            );
        }
        Ok(())
    }

    /// Generate an ML-KEM-768 keypair and store it as a transit key
    /// (creating or rotating `key_name`).
    ///
    /// The stored key material is the public key followed by the private key.
    #[cfg(feature = "pqc")]
    pub async fn create_pqc_key(&self, key_name: &str) -> Result<()> {
        let keypair = self
            .crypto
            .generate_keypair(KeyAlgorithm::MlKem768)
            .await
            .map_err(|e| VaultError::crypto(e.to_string()))?;
        let mut key_material = Vec::with_capacity(
            keypair.public_key.key_data.len() + keypair.private_key.key_data.len(),
        );
        key_material.extend_from_slice(&keypair.public_key.key_data);
        key_material.extend_from_slice(&keypair.private_key.key_data);
        self.create_key_with_algorithm(key_name, key_material, TransitKeyAlgorithm::MlKem768)
            .await
    }

    /// Encrypt `plaintext` under the current version of `key_name`.
    ///
    /// Returns `vault:v{version}:{base64}`. Symmetric keys encrypt directly
    /// with AES-256-GCM. ML-KEM-768 keys (only with the `pqc` feature) work
    /// differently: they encapsulate a fresh shared secret, use it as the
    /// AES-256-GCM key, and emit `[kem_ct_len:4][kem_ct][aes_ct]`.
    ///
    /// # Errors
    /// Unknown key, missing current version, malformed ML-KEM key material,
    /// or any crypto-backend failure.
    pub async fn encrypt(&self, key_name: &str, plaintext: &[u8]) -> Result<String> {
        let keys = self.keys.lock().await;
        let key = keys
            .get(key_name)
            .ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?;
        let current_version = key.current_version;
        let key_version = key
            .versions
            .get(&current_version)
            .ok_or_else(|| VaultError::crypto("Key version not found".to_string()))?;
        let key_material = key_version.key_material.clone();
        #[cfg(feature = "pqc")]
        let key_algorithm = key_version.algorithm;
        // Release the key map before awaiting on the crypto backend so other
        // requests are not serialized behind this one.
        drop(keys);

        #[cfg(feature = "pqc")]
        if key_algorithm == KeyAlgorithm::MlKem768 {
            let payload = self.mlkem_encrypt(&key_material, plaintext).await?;
            return Ok(Self::format_ciphertext(current_version, &payload));
        }

        // AES-256-GCM symmetric encryption (legacy path).
        let ciphertext = self
            .crypto
            .encrypt_symmetric(&key_material, plaintext, SymmetricAlgorithm::Aes256Gcm)
            .await
            .map_err(|e| VaultError::crypto(e.to_string()))?;
        Ok(Self::format_ciphertext(current_version, &ciphertext))
    }

    /// Decrypt a `vault:v{version}:{base64}` string with the key version it
    /// names.
    ///
    /// # Errors
    /// Malformed input, unknown key, a version below `min_decrypt_version` or
    /// not stored, malformed ML-KEM payload or key material, or any
    /// crypto-backend failure (including authentication failure).
    pub async fn decrypt(&self, key_name: &str, ciphertext_str: &str) -> Result<Vec<u8>> {
        let (version, ciphertext) = Self::parse_ciphertext(ciphertext_str)?;

        let keys = self.keys.lock().await;
        let key = keys
            .get(key_name)
            .ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?;
        if version < key.min_decrypt_version {
            return Err(VaultError::crypto(format!(
                "Key version {} is below minimum decrypt version {}",
                version, key.min_decrypt_version
            )));
        }
        let key_version = key
            .versions
            .get(&version)
            .ok_or_else(|| VaultError::crypto(format!("Key version {} not found", version)))?;
        let key_material = key_version.key_material.clone();
        #[cfg(feature = "pqc")]
        let key_algorithm = key_version.algorithm;
        // Release the key map before awaiting on the crypto backend.
        drop(keys);

        #[cfg(feature = "pqc")]
        if key_algorithm == KeyAlgorithm::MlKem768 {
            return self.mlkem_decrypt(&key_material, &ciphertext).await;
        }

        // AES-256-GCM symmetric decryption (legacy path).
        self.crypto
            .decrypt_symmetric(&key_material, &ciphertext, SymmetricAlgorithm::Aes256Gcm)
            .await
            .map_err(|e| VaultError::crypto(e.to_string()))
    }

    /// Rewrap ciphertext under the current key version: decrypt it with the
    /// version it names, then re-encrypt it with the current version.
    pub async fn rewrap(&self, key_name: &str, ciphertext_str: &str) -> Result<String> {
        let plaintext = self.decrypt(key_name, ciphertext_str).await?;
        self.encrypt(key_name, &plaintext).await
    }

    /// Render raw ciphertext bytes as `vault:v{version}:{base64}`.
    fn format_ciphertext(version: u64, raw: impl AsRef<[u8]>) -> String {
        format!("vault:v{}:{}", version, BASE64.encode(raw))
    }

    /// Parse `vault:v{version}:{base64}` into its version and decoded bytes.
    ///
    /// Exactly three `:`-separated segments are required.
    fn parse_ciphertext(ciphertext_str: &str) -> Result<(u64, Vec<u8>)> {
        let parts: Vec<&str> = ciphertext_str.split(':').collect();
        let (version_part, payload) = match parts.as_slice() {
            ["vault", version_part, payload] => (*version_part, *payload),
            _ => {
                return Err(VaultError::crypto(
                    "Invalid vault ciphertext format".to_string(),
                ))
            }
        };
        let version = version_part
            .strip_prefix('v')
            .ok_or_else(|| VaultError::crypto("Invalid version format".to_string()))?
            .parse::<u64>()
            .map_err(|e| VaultError::crypto(format!("Failed to parse version: {}", e)))?;
        let ciphertext = BASE64
            .decode(payload)
            .map_err(|e| VaultError::crypto(format!("Failed to decode ciphertext: {}", e)))?;
        Ok((version, ciphertext))
    }

    /// ML-KEM-768 wrapping: encapsulate a shared secret against the stored
    /// public key, AES-256-GCM-encrypt `plaintext` with it, and return
    /// `[kem_ct_len:4 BE][kem_ct][aes_ct]`.
    #[cfg(feature = "pqc")]
    async fn mlkem_encrypt(&self, key_material: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
        let public_key_data = key_material
            .get(..Self::ML_KEM_768_PUBLIC_KEY_LEN)
            .ok_or_else(|| VaultError::crypto("Invalid ML-KEM-768 key material".to_string()))?;
        let public_key = crate::crypto::PublicKey {
            algorithm: KeyAlgorithm::MlKem768,
            key_data: public_key_data.to_vec(),
        };
        let (kem_ct, shared_secret) = self
            .crypto
            .kem_encapsulate(&public_key)
            .await
            .map_err(|e| VaultError::crypto(format!("KEM encapsulation failed: {}", e)))?;
        let aes_ct = self
            .crypto
            .encrypt_symmetric(&shared_secret, plaintext, SymmetricAlgorithm::Aes256Gcm)
            .await
            .map_err(|e| VaultError::crypto(e.to_string()))?;

        // The length prefix is a u32; refuse rather than silently truncate.
        if kem_ct.len() > u32::MAX as usize {
            return Err(VaultError::crypto("KEM ciphertext too large".to_string()));
        }
        let kem_ct_len = kem_ct.len() as u32;
        let mut payload =
            Vec::with_capacity(Self::KEM_CT_LEN_PREFIX + kem_ct.len() + aes_ct.len());
        payload.extend_from_slice(&kem_ct_len.to_be_bytes());
        payload.extend_from_slice(&kem_ct);
        payload.extend_from_slice(&aes_ct);
        Ok(payload)
    }

    /// ML-KEM-768 unwrapping: the inverse of `mlkem_encrypt`.
    #[cfg(feature = "pqc")]
    async fn mlkem_decrypt(&self, key_material: &[u8], payload: &[u8]) -> Result<Vec<u8>> {
        let (kem_ct, aes_ct) = Self::split_kem_payload(payload)?;

        // The private key follows the public key in the stored material.
        let sk_start = Self::ML_KEM_768_PUBLIC_KEY_LEN;
        let sk_end = sk_start + Self::ML_KEM_768_PRIVATE_KEY_LEN;
        let private_key_data = key_material
            .get(sk_start..sk_end)
            .ok_or_else(|| VaultError::crypto("Invalid ML-KEM-768 key material".to_string()))?;
        let private_key = crate::crypto::PrivateKey {
            algorithm: KeyAlgorithm::MlKem768,
            key_data: private_key_data.to_vec(),
        };
        let shared_secret = self
            .crypto
            .kem_decapsulate(&private_key, kem_ct)
            .await
            .map_err(|e| VaultError::crypto(format!("KEM decapsulation failed: {}", e)))?;
        self.crypto
            .decrypt_symmetric(&shared_secret, aes_ct, SymmetricAlgorithm::Aes256Gcm)
            .await
            .map_err(|e| VaultError::crypto(e.to_string()))
    }

    /// Split `[kem_ct_len:4 BE][kem_ct][aes_ct]` into `(kem_ct, aes_ct)`.
    #[cfg(feature = "pqc")]
    fn split_kem_payload(payload: &[u8]) -> Result<(&[u8], &[u8])> {
        if payload.len() < Self::KEM_CT_LEN_PREFIX {
            return Err(VaultError::crypto(
                "Invalid KEM ciphertext format".to_string(),
            ));
        }
        let (prefix, body) = payload.split_at(Self::KEM_CT_LEN_PREFIX);
        let kem_ct_len = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
        // Compare against the remaining length instead of computing
        // `4 + kem_ct_len`, which can overflow on 32-bit targets.
        if body.len() < kem_ct_len {
            return Err(VaultError::crypto("Truncated KEM ciphertext".to_string()));
        }
        Ok(body.split_at(kem_ct_len))
    }
}
#[async_trait]
impl Engine for TransitEngine {
fn name(&self) -> &str {
"transit"
}
fn engine_type(&self) -> &str {
"transit"
}
async fn read(&self, path: &str) -> Result<Option<Value>> {
if let Some(key_name) = path.strip_prefix("keys/") {
let keys = self.keys.lock().await;
if let Some(key) = keys.get(key_name) {
let key_version = key.versions.get(&key.current_version);
let mut response = json!({
"name": key.name,
"current_version": key.current_version,
"min_decrypt_version": key.min_decrypt_version,
});
// Add public key and creation timestamp for PQC keys
if let Some(kv) = key_version {
response["algorithm"] = json!(kv.algorithm.as_str());
response["created_at"] = json!(kv.created_at.to_rfc3339());
// For ML-KEM-768, extract and base64 encode the public key
#[cfg(feature = "pqc")]
if kv.algorithm == KeyAlgorithm::MlKem768 && kv.key_material.len() >= 1184 {
let public_key_data = &kv.key_material[..1184];
response["public_key"] = json!(BASE64.encode(public_key_data));
}
}
return Ok(Some(response));
}
}
Ok(None)
}
async fn write(&self, path: &str, data: &Value) -> Result<()> {
if let Some(key_name) = path.strip_prefix("encrypt/") {
let plaintext = data
.get("plaintext")
.and_then(|v| v.as_str())
.ok_or_else(|| VaultError::storage("Missing 'plaintext' in request".to_string()))?;
let _ciphertext = self.encrypt(key_name, plaintext.as_bytes()).await?;
// Note: In a full implementation, this would return the ciphertext
// in the response
} else if let Some(key_name) = path.strip_prefix("decrypt/") {
let ciphertext = data
.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| {
VaultError::storage("Missing 'ciphertext' in request".to_string())
})?;
let _plaintext = self.decrypt(key_name, ciphertext).await?;
// Note: In a full implementation, this would return the plaintext
// in the response
} else if let Some(key_name) = path.strip_prefix("rewrap/") {
let ciphertext = data
.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| {
VaultError::storage("Missing 'ciphertext' in request".to_string())
})?;
let _new_ciphertext = self.rewrap(key_name, ciphertext).await?;
// Note: In a full implementation, this would return the new
// ciphertext in the response
} else if let Some(rest) = path.strip_prefix("pqc-keys/") {
if rest.ends_with("/generate") {
let _key_name = rest.trim_end_matches("/generate");
#[cfg(feature = "pqc")]
self.create_pqc_key(_key_name).await?;
}
}
Ok(())
}
async fn delete(&self, path: &str) -> Result<()> {
if let Some(key_name) = path.strip_prefix("keys/") {
let mut keys = self.keys.lock().await;
keys.remove(key_name);
}
Ok(())
}
async fn list(&self, prefix: &str) -> Result<Vec<String>> {
let keys = self.keys.lock().await;
let mut result = Vec::new();
for key_name in keys.keys() {
if key_name.starts_with(prefix) {
result.push(key_name.clone());
}
}
Ok(result)
}
async fn health_check(&self) -> Result<()> {
self.storage
.health_check()
.await
.map_err(|e| VaultError::storage(e.to_string()))?;
let seal = self.seal.lock().await;
if seal.is_sealed() {
return Err(VaultError::crypto("Vault is sealed".to_string()));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;
    use crate::config::{FilesystemStorageConfig, SealConfig, ShamirSealConfig, StorageConfig};
    use crate::crypto::CryptoRegistry;
    use crate::storage::StorageRegistry;

    /// Key name shared by the tests below.
    const KEY: &str = "my-key";

    /// Build a `TransitEngine` for tests.
    ///
    /// The engine uses filesystem storage in a fresh temp dir, the "openssl"
    /// crypto backend, and an initialized 2-of-3 Shamir seal. The `TempDir`
    /// is returned so the caller keeps the directory alive for the test.
    async fn setup_engine() -> Result<(TransitEngine, TempDir)> {
        let temp_dir = TempDir::new().map_err(|e| VaultError::storage(e.to_string()))?;
        let fs_config = FilesystemStorageConfig {
            path: temp_dir.path().to_path_buf(),
        };
        let storage_config = StorageConfig {
            backend: "filesystem".to_string(),
            filesystem: fs_config,
            surrealdb: Default::default(),
            etcd: Default::default(),
            postgresql: Default::default(),
        };
        let storage = StorageRegistry::create(&storage_config).await?;
        let crypto = CryptoRegistry::create("openssl", &Default::default())?;
        let seal_config = SealConfig {
            seal_type: "shamir".to_string(),
            shamir: ShamirSealConfig {
                threshold: 2,
                shares: 3,
            },
            auto_unseal: Default::default(),
        };
        let mut seal = crate::core::SealMechanism::new(&seal_config)?;
        // Initialize the seal so the engine is usable in tests.
        let _init_result = seal.init(crypto.as_ref(), storage.as_ref()).await?;
        let seal_arc = Arc::new(tokio::sync::Mutex::new(seal));
        let engine = TransitEngine::new(storage, crypto.clone(), seal_arc, "transit/".to_string());
        Ok((engine, temp_dir))
    }

    #[tokio::test]
    async fn test_transit_encrypt_decrypt() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        let plaintext = b"sensitive data";
        engine.create_key(KEY, vec![0x42; 32]).await?;
        let ciphertext = engine.encrypt(KEY, plaintext).await?;
        assert!(ciphertext.starts_with("vault:v1:"));
        let decrypted = engine.decrypt(KEY, &ciphertext).await?;
        assert_eq!(decrypted, plaintext);
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_key_rotation() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.create_key(KEY, vec![0x11; 32]).await?;
        let ct1 = engine.encrypt(KEY, b"data v1").await?;
        // Rotate: the same name with new material becomes version 2.
        engine.create_key(KEY, vec![0x22; 32]).await?;
        let ct2 = engine.encrypt(KEY, b"data v2").await?;
        assert!(ct1.contains(":v1:"));
        assert!(ct2.contains(":v2:"));
        // Ciphertext from the old version must remain decryptable.
        assert_eq!(engine.decrypt(KEY, &ct1).await?, b"data v1");
        assert_eq!(engine.decrypt(KEY, &ct2).await?, b"data v2");
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_rewrap() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.create_key(KEY, vec![0x42; 32]).await?;
        let ct1 = engine.encrypt(KEY, b"test data").await?;
        // Rotate and rewrap.
        engine.create_key(KEY, vec![0x99; 32]).await?;
        let ct2 = engine.rewrap(KEY, &ct1).await?;
        // The rewrapped ciphertext uses the new version and still round-trips.
        assert!(ct2.contains(":v2:"));
        assert_eq!(engine.decrypt(KEY, &ct2).await?, b"test data");
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_invalid_ciphertext() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.create_key(KEY, vec![0x42; 32]).await?;
        let malformed = [
            "invalid:format",      // wrong number of segments
            "nope:v1:AAAA",        // wrong prefix
            "vault:v1:AAAA:extra", // too many segments
            "vault:1:AAAA",        // version without 'v'
            "vault:vone:AAAA",     // non-numeric version
            "vault:v1:@@@@",       // invalid base64
            "vault:v0:AAAA",       // below min_decrypt_version
            "vault:v9:AAAA",       // version that was never created
        ];
        for input in malformed.iter() {
            assert!(
                engine.decrypt(KEY, input).await.is_err(),
                "expected an error for {:?}",
                input
            );
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_unknown_key() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        assert!(engine.encrypt("missing", b"data").await.is_err());
        engine.create_key(KEY, vec![0x42; 32]).await?;
        let ciphertext = engine.encrypt(KEY, b"data").await?;
        assert!(engine.decrypt("missing", &ciphertext).await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_read_list_delete() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.create_key(KEY, vec![0x11; 32]).await?;
        engine.create_key(KEY, vec![0x22; 32]).await?;
        engine.create_key("other", vec![0x33; 32]).await?;
        let path = format!("keys/{}", KEY);

        let meta = engine
            .read(&path)
            .await?
            .expect("metadata for an existing key");
        assert_eq!(meta["name"].as_str(), Some(KEY));
        assert_eq!(meta["current_version"].as_u64(), Some(2));
        assert_eq!(meta["min_decrypt_version"].as_u64(), Some(1));
        assert!(engine.read("keys/missing").await?.is_none());

        assert_eq!(engine.list("my-").await?, vec![KEY.to_string()]);
        assert_eq!(engine.list("").await?.len(), 2);

        engine.delete(&path).await?;
        assert!(engine.read(&path).await?.is_none());
        assert!(engine.encrypt(KEY, b"data").await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_health_check() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.health_check().await?;
        Ok(())
    }
}