//! PostgreSQL storage backend for SecretumVault //! //! Provides persistent secret storage using PostgreSQL as the backend. //! This implementation uses an in-memory store (production would use sqlx + real DB). use async_trait::async_trait; use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; use crate::config::PostgreSQLStorageConfig; use crate::error::{StorageError, StorageResult}; use crate::storage::{EncryptedData, Lease, StorageBackend, StoredKey, StoredPolicy}; /// PostgreSQL storage backend for secrets persistence pub struct PostgreSQLBackend { // In-memory storage (production would use actual PostgreSQL) secrets: Arc>>>, keys: Arc>>>, policies: Arc>>>, leases: Arc>>>, connection_string: String, } impl PostgreSQLBackend { /// Create a new PostgreSQL backend instance pub fn new(config: &PostgreSQLStorageConfig) -> std::result::Result { if !config.connection_string.starts_with("postgres://") { return Err(StorageError::Internal( "Invalid PostgreSQL connection string".to_string(), )); } Ok(Self { secrets: Arc::new(RwLock::new(HashMap::new())), keys: Arc::new(RwLock::new(HashMap::new())), policies: Arc::new(RwLock::new(HashMap::new())), leases: Arc::new(RwLock::new(HashMap::new())), connection_string: config.connection_string.clone(), }) } } impl std::fmt::Debug for PostgreSQLBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PostgreSQLBackend") .field("connection_string", &self.connection_string) .finish() } } #[async_trait] impl StorageBackend for PostgreSQLBackend { async fn store_secret(&self, path: &str, data: &EncryptedData) -> StorageResult<()> { let serialized = serde_json::to_vec(&data).map_err(|e| StorageError::Serialization(e.to_string()))?; let mut secrets = self.secrets.write().await; secrets.insert(path.to_string(), serialized); Ok(()) } async fn get_secret(&self, path: &str) -> StorageResult { let secrets = self.secrets.read().await; match secrets.get(path) { Some(data) => { serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string())) } None => Err(StorageError::NotFound(path.to_string())), } } async fn delete_secret(&self, path: &str) -> StorageResult<()> { let mut secrets = self.secrets.write().await; secrets.remove(path); Ok(()) } async fn list_secrets(&self, prefix: &str) -> StorageResult> { let secrets = self.secrets.read().await; let results: Vec = secrets .keys() .filter(|k| k.starts_with(prefix)) .cloned() .collect(); Ok(results) } async fn store_key(&self, key: &StoredKey) -> StorageResult<()> { let serialized = serde_json::to_vec(&key).map_err(|e| StorageError::Serialization(e.to_string()))?; let mut keys = self.keys.write().await; keys.insert(key.id.clone(), serialized); Ok(()) } async fn get_key(&self, key_id: &str) -> StorageResult { let keys = self.keys.read().await; match keys.get(key_id) { Some(data) => { serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string())) } None => Err(StorageError::NotFound(key_id.to_string())), } } async fn list_keys(&self) -> StorageResult> { let keys = self.keys.read().await; let results: Vec = keys.keys().cloned().collect(); Ok(results) } async fn store_policy(&self, name: &str, policy: &StoredPolicy) -> StorageResult<()> { let serialized = serde_json::to_vec(&policy).map_err(|e| StorageError::Serialization(e.to_string()))?; let mut policies = self.policies.write().await; policies.insert(name.to_string(), serialized); Ok(()) } async fn get_policy(&self, name: &str) -> StorageResult { let policies = self.policies.read().await; match policies.get(name) { Some(data) => { serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string())) } None => Err(StorageError::NotFound(name.to_string())), } } async fn list_policies(&self) -> StorageResult> { let policies = self.policies.read().await; let results: Vec = policies.keys().cloned().collect(); Ok(results) } async fn store_lease(&self, lease: &Lease) -> StorageResult<()> { let serialized = serde_json::to_vec(&lease).map_err(|e| StorageError::Serialization(e.to_string()))?; let mut leases = self.leases.write().await; leases.insert(lease.id.clone(), serialized); Ok(()) } async fn get_lease(&self, lease_id: &str) -> StorageResult { let leases = self.leases.read().await; match leases.get(lease_id) { Some(data) => { serde_json::from_slice(data).map_err(|e| StorageError::Serialization(e.to_string())) } None => Err(StorageError::NotFound(lease_id.to_string())), } } async fn delete_lease(&self, lease_id: &str) -> StorageResult<()> { let mut leases = self.leases.write().await; leases.remove(lease_id); Ok(()) } async fn list_expiring_leases(&self, before: DateTime) -> StorageResult> { let leases = self.leases.read().await; let mut results = Vec::new(); for data in leases.values() { if let Ok(lease) = serde_json::from_slice::(data) { if lease.expires_at <= before { results.push(lease); } } } Ok(results) } async fn health_check(&self) -> StorageResult<()> { // Simple check: verify we can access the storage let _secrets = self.secrets.read().await; Ok(()) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_postgresql_backend_creation() -> std::result::Result<(), StorageError> { let config = PostgreSQLStorageConfig::default(); let backend = PostgreSQLBackend::new(&config)?; backend.health_check().await?; Ok(()) } #[tokio::test] async fn test_postgresql_invalid_connection_string() { let config = PostgreSQLStorageConfig { connection_string: "invalid://string".to_string(), }; assert!(PostgreSQLBackend::new(&config).is_err()); } #[tokio::test] async fn test_postgresql_store_and_get_secret() -> std::result::Result<(), StorageError> { let config = PostgreSQLStorageConfig::default(); let backend = PostgreSQLBackend::new(&config)?; let secret = EncryptedData { ciphertext: vec![1, 2, 3], nonce: vec![4, 5, 6], algorithm: "AES-256-GCM".to_string(), }; backend.store_secret("test/secret", &secret).await?; let retrieved = backend.get_secret("test/secret").await?; assert_eq!(retrieved.ciphertext, secret.ciphertext); assert_eq!(retrieved.algorithm, secret.algorithm); Ok(()) } }