406 lines
13 KiB
Rust
Raw Normal View History

use std::collections::HashMap;
use std::sync::Arc;
2025-12-22 21:34:01 +00:00
use async_trait::async_trait;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use serde_json::{json, Value};
use super::Engine;
use crate::core::SealMechanism;
use crate::crypto::{CryptoBackend, SymmetricAlgorithm};
use crate::error::{Result, VaultError};
use crate::storage::StorageBackend;
/// A named transit key together with every version of its key material.
///
/// New versions are appended by `TransitEngine::create_key`; encryption uses
/// `current_version`, while decryption may use any stored version that is not
/// below `min_decrypt_version`.
#[derive(Debug, Clone)]
struct TransitKey {
/// Key name, as supplied to `create_key` and reported by `Engine::read`.
name: String,
/// Key material for each version, indexed by version number (first version is 1).
versions: HashMap<u64, KeyVersion>,
/// Version used for new encryptions; bumped each time the key is rotated.
current_version: u64,
/// Lowest version `decrypt` will accept. Set to 1 on creation; nothing in the
/// visible code raises it afterwards.
min_decrypt_version: u64,
}
/// A single version of a transit key.
///
/// `Debug` is implemented by hand rather than derived so that the raw key
/// bytes can never end up in logs or panic messages via `{:?}` (this also
/// covers the derived `Debug` of any struct that embeds a `KeyVersion`).
#[derive(Clone)]
struct KeyVersion {
    /// Raw symmetric key bytes handed to the crypto backend.
    key_material: Vec<u8>,
    /// Moment this version was created.
    #[allow(dead_code)]
    created_at: chrono::DateTime<chrono::Utc>,
}

impl std::fmt::Debug for KeyVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyVersion")
            // Never print secret key bytes.
            .field("key_material", &"<redacted>")
            .field("created_at", &self.created_at)
            .finish()
    }
}
/// Transit secrets engine for encryption/decryption
///
/// Holds named, versioned keys and uses the crypto backend to encrypt and
/// decrypt caller-supplied data with them.
pub struct TransitEngine {
/// Storage backend. In the code visible here it is only probed by
/// `health_check`; keys are not written to it.
storage: Arc<dyn StorageBackend>,
/// Crypto backend performing the symmetric encrypt/decrypt operations.
crypto: Arc<dyn CryptoBackend>,
/// Seal mechanism; `health_check` fails while it reports sealed.
seal: Arc<tokio::sync::Mutex<SealMechanism>>,
/// Mount path prefix used by `storage_key` to build storage paths.
#[allow(dead_code)]
mount_path: String,
/// In-memory table of transit keys, indexed by key name.
keys: Arc<tokio::sync::Mutex<HashMap<String, TransitKey>>>,
}
impl TransitEngine {
/// Create a new Transit engine instance
pub fn new(
storage: Arc<dyn StorageBackend>,
crypto: Arc<dyn CryptoBackend>,
seal: Arc<tokio::sync::Mutex<SealMechanism>>,
mount_path: String,
) -> Self {
Self {
storage,
crypto,
seal,
mount_path,
keys: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
}
}
/// Get storage key for transit key
#[allow(dead_code)]
fn storage_key(&self, key_name: &str) -> String {
format!("{}keys/{}", self.mount_path, key_name)
}
/// Create or update a transit key
pub async fn create_key(&self, key_name: &str, key_material: Vec<u8>) -> Result<()> {
let now = chrono::Utc::now();
let mut keys = self.keys.lock().await;
if let Some(key) = keys.get_mut(key_name) {
// Existing key - increment version
let next_version = key.current_version + 1;
key.versions.insert(
next_version,
KeyVersion {
key_material,
created_at: now,
},
);
key.current_version = next_version;
} else {
// New key - create with version 1
let mut key = TransitKey {
name: key_name.to_string(),
versions: HashMap::new(),
current_version: 1,
min_decrypt_version: 1,
};
key.versions.insert(
1,
KeyVersion {
key_material,
created_at: now,
},
);
keys.insert(key_name.to_string(), key);
}
Ok(())
}
/// Encrypt plaintext using the specified key
pub async fn encrypt(&self, key_name: &str, plaintext: &[u8]) -> Result<String> {
let keys = self.keys.lock().await;
let key = keys
.get(key_name)
.ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?;
let key_version = key
.versions
.get(&key.current_version)
.ok_or_else(|| VaultError::crypto("Key version not found".to_string()))?;
let key_material = key_version.key_material.clone();
let current_version = key.current_version;
drop(keys);
// Encrypt plaintext using the current key version (lock is dropped before
// await)
2025-12-22 21:34:01 +00:00
let ciphertext = self
.crypto
.encrypt_symmetric(&key_material, plaintext, SymmetricAlgorithm::Aes256Gcm)
.await
.map_err(|e| VaultError::crypto(e.to_string()))?;
// Format: vault:v{version}:base64_encoded_ciphertext
let encoded = BASE64.encode(&ciphertext);
Ok(format!("vault:v{}:{}", current_version, encoded))
}
/// Decrypt ciphertext using the appropriate key version
pub async fn decrypt(&self, key_name: &str, ciphertext_str: &str) -> Result<Vec<u8>> {
// Parse vault format: vault:v{version}:base64_data
let parts: Vec<&str> = ciphertext_str.split(':').collect();
if parts.len() != 3 || parts[0] != "vault" {
return Err(VaultError::crypto(
"Invalid vault ciphertext format".to_string(),
));
}
let version_str = parts[1]
.strip_prefix('v')
.ok_or_else(|| VaultError::crypto("Invalid version format".to_string()))?;
let version: u64 = version_str
.parse()
.map_err(|e| VaultError::crypto(format!("Failed to parse version: {}", e)))?;
let ciphertext = BASE64
.decode(parts[2])
.map_err(|e| VaultError::crypto(format!("Failed to decode ciphertext: {}", e)))?;
let keys = self.keys.lock().await;
let key = keys
.get(key_name)
.ok_or_else(|| VaultError::storage(format!("Key not found: {}", key_name)))?;
if version < key.min_decrypt_version {
return Err(VaultError::crypto(format!(
"Key version {} is below minimum decrypt version {}",
version, key.min_decrypt_version
)));
}
let key_version = key
.versions
.get(&version)
.ok_or_else(|| VaultError::crypto(format!("Key version {} not found", version)))?;
let key_material = key_version.key_material.clone();
drop(keys);
self.crypto
.decrypt_symmetric(&key_material, &ciphertext, SymmetricAlgorithm::Aes256Gcm)
.await
.map_err(|e| VaultError::crypto(e.to_string()))
}
/// Rewrap ciphertext under the current key version
pub async fn rewrap(&self, key_name: &str, ciphertext_str: &str) -> Result<String> {
let plaintext = self.decrypt(key_name, ciphertext_str).await?;
self.encrypt(key_name, &plaintext).await
}
}
#[async_trait]
impl Engine for TransitEngine {
fn name(&self) -> &str {
"transit"
}
fn engine_type(&self) -> &str {
"transit"
}
async fn read(&self, path: &str) -> Result<Option<Value>> {
if let Some(key_name) = path.strip_prefix("keys/") {
let keys = self.keys.lock().await;
if let Some(key) = keys.get(key_name) {
return Ok(Some(json!({
"name": key.name,
"current_version": key.current_version,
"min_decrypt_version": key.min_decrypt_version,
})));
}
}
Ok(None)
}
async fn write(&self, path: &str, data: &Value) -> Result<()> {
if let Some(key_name) = path.strip_prefix("encrypt/") {
let plaintext = data
.get("plaintext")
.and_then(|v| v.as_str())
.ok_or_else(|| VaultError::storage("Missing 'plaintext' in request".to_string()))?;
let _ciphertext = self.encrypt(key_name, plaintext.as_bytes()).await?;
// Note: In a full implementation, this would return the ciphertext
// in the response
2025-12-22 21:34:01 +00:00
} else if let Some(key_name) = path.strip_prefix("decrypt/") {
let ciphertext = data
.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| {
VaultError::storage("Missing 'ciphertext' in request".to_string())
})?;
let _plaintext = self.decrypt(key_name, ciphertext).await?;
// Note: In a full implementation, this would return the plaintext
// in the response
2025-12-22 21:34:01 +00:00
} else if let Some(key_name) = path.strip_prefix("rewrap/") {
let ciphertext = data
.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| {
VaultError::storage("Missing 'ciphertext' in request".to_string())
})?;
let _new_ciphertext = self.rewrap(key_name, ciphertext).await?;
// Note: In a full implementation, this would return the new
// ciphertext in the response
2025-12-22 21:34:01 +00:00
}
Ok(())
}
async fn delete(&self, path: &str) -> Result<()> {
if let Some(key_name) = path.strip_prefix("keys/") {
let mut keys = self.keys.lock().await;
keys.remove(key_name);
}
Ok(())
}
async fn list(&self, prefix: &str) -> Result<Vec<String>> {
let keys = self.keys.lock().await;
let mut result = Vec::new();
for key_name in keys.keys() {
if key_name.starts_with(prefix) {
result.push(key_name.clone());
}
}
Ok(result)
}
async fn health_check(&self) -> Result<()> {
self.storage
.health_check()
.await
.map_err(|e| VaultError::storage(e.to_string()))?;
let seal = self.seal.lock().await;
if seal.is_sealed() {
return Err(VaultError::crypto("Vault is sealed".to_string()));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;
    use crate::config::{FilesystemStorageConfig, SealConfig, ShamirSealConfig, StorageConfig};
    use crate::crypto::CryptoRegistry;
    use crate::storage::StorageRegistry;

    /// Build a transit engine backed by a temporary filesystem store.
    ///
    /// The `TempDir` is returned so the directory outlives the engine for the
    /// duration of the test.
    async fn setup_engine() -> Result<(TransitEngine, TempDir)> {
        let temp_dir = TempDir::new().map_err(|e| VaultError::storage(e.to_string()))?;
        let fs_config = FilesystemStorageConfig {
            path: temp_dir.path().to_path_buf(),
        };
        let storage_config = StorageConfig {
            backend: "filesystem".to_string(),
            filesystem: fs_config,
            surrealdb: Default::default(),
            etcd: Default::default(),
            postgresql: Default::default(),
        };
        let storage = StorageRegistry::create(&storage_config).await?;
        let crypto = CryptoRegistry::create("openssl", &Default::default())?;
        let seal_config = SealConfig {
            seal_type: "shamir".to_string(),
            shamir: ShamirSealConfig {
                threshold: 2,
                shares: 3,
            },
            auto_unseal: Default::default(),
        };
        let mut seal = crate::core::SealMechanism::new(&seal_config)?;
        // Initialize the seal for testing. No explicit unseal call is made
        // here; the health-check test relies on `init` leaving the seal in an
        // unsealed state.
        let _init_result = seal.init(crypto.as_ref(), storage.as_ref()).await?;
        let seal_arc = Arc::new(tokio::sync::Mutex::new(seal));
        let engine = TransitEngine::new(storage, crypto.clone(), seal_arc, "transit/".to_string());
        Ok((engine, temp_dir))
    }

    /// Key name shared by the tests below.
    fn mock_key_name() -> String {
        "my-key".to_string()
    }

    #[tokio::test]
    async fn test_transit_encrypt_decrypt() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        let key = mock_key_name();
        let plaintext = b"sensitive data";
        engine.create_key(&key, vec![0x42; 32]).await?;

        let ciphertext = engine.encrypt(&key, plaintext).await?;
        assert!(ciphertext.starts_with("vault:v"));

        let decrypted = engine.decrypt(&key, &ciphertext).await?;
        assert_eq!(decrypted, plaintext);
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_key_rotation() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        let key = mock_key_name();
        engine.create_key(&key, vec![0x11; 32]).await?;
        let ct1 = engine.encrypt(&key, b"data v1").await?;

        // Rotate key
        engine.create_key(&key, vec![0x22; 32]).await?;
        let ct2 = engine.encrypt(&key, b"data v2").await?;

        // Should use different versions
        assert!(ct1.contains(":v1:"));
        assert!(ct2.contains(":v2:"));

        // Data encrypted before the rotation must remain readable afterwards.
        assert_eq!(engine.decrypt(&key, &ct1).await?, b"data v1");
        assert_eq!(engine.decrypt(&key, &ct2).await?, b"data v2");
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_rewrap() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        let key = mock_key_name();
        engine.create_key(&key, vec![0x42; 32]).await?;
        let ct1 = engine.encrypt(&key, b"test data").await?;

        // Rotate and rewrap
        engine.create_key(&key, vec![0x99; 32]).await?;
        let ct2 = engine.rewrap(&key, &ct1).await?;

        // Rewrapped should use new version
        assert!(ct2.contains(":v2:"));
        // ...and still carry the original plaintext.
        assert_eq!(engine.decrypt(&key, &ct2).await?, b"test data");
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_invalid_ciphertext() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        let key = mock_key_name();
        engine.create_key(&key, vec![0x42; 32]).await?;

        // Wrong number of segments.
        assert!(engine.decrypt(&key, "invalid:format").await.is_err());
        // Wrong prefix.
        assert!(engine.decrypt(&key, "other:v1:AAAA").await.is_err());
        // Version is not a number.
        assert!(engine.decrypt(&key, "vault:vX:AAAA").await.is_err());
        // Payload is not valid base64.
        assert!(engine.decrypt(&key, "vault:v1:not base64!").await.is_err());
        // Version that was never created.
        assert!(engine.decrypt(&key, "vault:v99:AAAA").await.is_err());
        // Key that was never created.
        assert!(engine.decrypt("missing-key", "vault:v1:AAAA").await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_transit_health_check() -> Result<()> {
        let (engine, _temp) = setup_engine().await?;
        engine.health_check().await?;
        Ok(())
    }
}