secretumvault/src/storage/postgresql.rs
2025-12-22 21:34:01 +00:00

227 lines
7.6 KiB
Rust

//! PostgreSQL storage backend for SecretumVault
//!
//! Intended to provide persistent secret storage backed by PostgreSQL.
//! The current implementation keeps all data in in-memory maps, so nothing survives a
//! process restart; a production version would use sqlx against a real database.
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::config::PostgreSQLStorageConfig;
use crate::error::{StorageError, StorageResult};
use crate::storage::{EncryptedData, Lease, StorageBackend, StoredKey, StoredPolicy};
/// PostgreSQL storage backend for secrets persistence
pub struct PostgreSQLBackend {
// In-memory storage (production would use actual PostgreSQL)
secrets: Arc<RwLock<HashMap<String, Vec<u8>>>>,
keys: Arc<RwLock<HashMap<String, Vec<u8>>>>,
policies: Arc<RwLock<HashMap<String, Vec<u8>>>>,
leases: Arc<RwLock<HashMap<String, Vec<u8>>>>,
connection_string: String,
}
impl PostgreSQLBackend {
    /// Connection-URI prefixes accepted by [`PostgreSQLBackend::new`].
    ///
    /// libpq treats `postgresql://` and `postgres://` as equivalent URI scheme designators,
    /// so both are allowed.
    const ACCEPTED_SCHEMES: [&'static str; 2] = ["postgres://", "postgresql://"];

    /// Create a new PostgreSQL backend instance.
    ///
    /// The connection string is only checked for a recognised URI scheme and then stored;
    /// no connection is opened, because data is currently held in in-memory maps.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Internal`] if `config.connection_string` does not start with
    /// `postgres://` or `postgresql://`.
    pub fn new(config: &PostgreSQLStorageConfig) -> std::result::Result<Self, StorageError> {
        let connection_string = &config.connection_string;
        let has_valid_scheme = Self::ACCEPTED_SCHEMES
            .iter()
            .any(|&scheme| connection_string.starts_with(scheme));
        if !has_valid_scheme {
            return Err(StorageError::Internal(
                "Invalid PostgreSQL connection string".to_string(),
            ));
        }
        Ok(Self {
            secrets: Arc::new(RwLock::new(HashMap::new())),
            keys: Arc::new(RwLock::new(HashMap::new())),
            policies: Arc::new(RwLock::new(HashMap::new())),
            leases: Arc::new(RwLock::new(HashMap::new())),
            connection_string: connection_string.clone(),
        })
    }
}
impl std::fmt::Debug for PostgreSQLBackend {
    /// Debug-format the backend without exposing credentials.
    ///
    /// A PostgreSQL connection URI can embed secrets (`postgres://user:password@host/db`
    /// or a `?password=` query parameter), so only its scheme is printed; the remainder is
    /// replaced with `<redacted>`. Strings without a `scheme://` prefix print as
    /// `<unknown>://<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let scheme = self
            .connection_string
            .split_once("://")
            .map_or("<unknown>", |(scheme, _)| scheme);
        f.debug_struct("PostgreSQLBackend")
            .field(
                "connection_string",
                &format_args!("{}://<redacted>", scheme),
            )
            .finish()
    }
}
#[async_trait]
impl StorageBackend for PostgreSQLBackend {
async fn store_secret(&self, path: &str, data: &EncryptedData) -> StorageResult<()> {
let serialized =
serde_json::to_vec(&data).map_err(|e| StorageError::Serialization(e.to_string()))?;
let mut secrets = self.secrets.write().await;
secrets.insert(path.to_string(), serialized);
Ok(())
}
async fn get_secret(&self, path: &str) -> StorageResult<EncryptedData> {
let secrets = self.secrets.read().await;
match secrets.get(path) {
Some(data) => {
serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string()))
}
None => Err(StorageError::NotFound(path.to_string())),
}
}
async fn delete_secret(&self, path: &str) -> StorageResult<()> {
let mut secrets = self.secrets.write().await;
secrets.remove(path);
Ok(())
}
async fn list_secrets(&self, prefix: &str) -> StorageResult<Vec<String>> {
let secrets = self.secrets.read().await;
let results: Vec<String> = secrets
.keys()
.filter(|k| k.starts_with(prefix))
.cloned()
.collect();
Ok(results)
}
async fn store_key(&self, key: &StoredKey) -> StorageResult<()> {
let serialized =
serde_json::to_vec(&key).map_err(|e| StorageError::Serialization(e.to_string()))?;
let mut keys = self.keys.write().await;
keys.insert(key.id.clone(), serialized);
Ok(())
}
async fn get_key(&self, key_id: &str) -> StorageResult<StoredKey> {
let keys = self.keys.read().await;
match keys.get(key_id) {
Some(data) => {
serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string()))
}
None => Err(StorageError::NotFound(key_id.to_string())),
}
}
async fn list_keys(&self) -> StorageResult<Vec<String>> {
let keys = self.keys.read().await;
let results: Vec<String> = keys.keys().cloned().collect();
Ok(results)
}
async fn store_policy(&self, name: &str, policy: &StoredPolicy) -> StorageResult<()> {
let serialized =
serde_json::to_vec(&policy).map_err(|e| StorageError::Serialization(e.to_string()))?;
let mut policies = self.policies.write().await;
policies.insert(name.to_string(), serialized);
Ok(())
}
async fn get_policy(&self, name: &str) -> StorageResult<StoredPolicy> {
let policies = self.policies.read().await;
match policies.get(name) {
Some(data) => {
serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string()))
}
None => Err(StorageError::NotFound(name.to_string())),
}
}
async fn list_policies(&self) -> StorageResult<Vec<String>> {
let policies = self.policies.read().await;
let results: Vec<String> = policies.keys().cloned().collect();
Ok(results)
}
async fn store_lease(&self, lease: &Lease) -> StorageResult<()> {
let serialized =
serde_json::to_vec(&lease).map_err(|e| StorageError::Serialization(e.to_string()))?;
let mut leases = self.leases.write().await;
leases.insert(lease.id.clone(), serialized);
Ok(())
}
async fn get_lease(&self, lease_id: &str) -> StorageResult<Lease> {
let leases = self.leases.read().await;
match leases.get(lease_id) {
Some(data) => {
serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string()))
}
None => Err(StorageError::NotFound(lease_id.to_string())),
}
}
async fn delete_lease(&self, lease_id: &str) -> StorageResult<()> {
let mut leases = self.leases.write().await;
leases.remove(lease_id);
Ok(())
}
async fn list_expiring_leases(&self, before: DateTime<Utc>) -> StorageResult<Vec<Lease>> {
let leases = self.leases.read().await;
let mut results = Vec::new();
for data in leases.values() {
if let Ok(lease) = serde_json::from_slice::<Lease>(data) {
if lease.expires_at <= before {
results.push(lease);
}
}
}
Ok(results)
}
async fn health_check(&self) -> StorageResult<()> {
// Simple check: verify we can access the storage
let _secrets = self.secrets.read().await;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a backend from the default configuration.
    fn default_backend() -> std::result::Result<PostgreSQLBackend, StorageError> {
        PostgreSQLBackend::new(&PostgreSQLStorageConfig::default())
    }

    /// Small fixed secret payload shared by the round-trip tests.
    fn sample_secret() -> EncryptedData {
        EncryptedData {
            ciphertext: vec![1, 2, 3],
            nonce: vec![4, 5, 6],
            algorithm: "AES-256-GCM".to_string(),
        }
    }

    #[tokio::test]
    async fn test_postgresql_backend_creation() -> std::result::Result<(), StorageError> {
        let backend = default_backend()?;
        backend.health_check().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_postgresql_invalid_connection_string() {
        let config = PostgreSQLStorageConfig {
            connection_string: "invalid://string".to_string(),
        };
        assert!(PostgreSQLBackend::new(&config).is_err());
    }

    #[tokio::test]
    async fn test_postgresql_store_and_get_secret() -> std::result::Result<(), StorageError> {
        let backend = default_backend()?;
        let secret = sample_secret();
        backend.store_secret("test/secret", &secret).await?;
        let retrieved = backend.get_secret("test/secret").await?;
        assert_eq!(retrieved.ciphertext, secret.ciphertext);
        assert_eq!(retrieved.nonce, secret.nonce);
        assert_eq!(retrieved.algorithm, secret.algorithm);
        Ok(())
    }

    #[tokio::test]
    async fn test_postgresql_get_missing_secret_is_not_found(
    ) -> std::result::Result<(), StorageError> {
        let backend = default_backend()?;
        let result = backend.get_secret("missing/secret").await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
        Ok(())
    }

    #[tokio::test]
    async fn test_postgresql_list_and_delete_secrets() -> std::result::Result<(), StorageError> {
        let backend = default_backend()?;
        let secret = sample_secret();
        for path in ["app/db", "app/api", "other/token"] {
            backend.store_secret(path, &secret).await?;
        }

        // Sort locally so the assertion does not depend on the backend's listing order.
        let mut listed = backend.list_secrets("app/").await?;
        listed.sort();
        assert_eq!(listed, vec!["app/api".to_string(), "app/db".to_string()]);

        backend.delete_secret("app/db").await?;
        assert!(matches!(
            backend.get_secret("app/db").await,
            Err(StorageError::NotFound(_))
        ));
        assert_eq!(
            backend.list_secrets("app/").await?,
            vec!["app/api".to_string()]
        );
        Ok(())
    }
}