//! MFA device storage and persistence use async_trait::async_trait; use chrono::{DateTime, Utc}; use sqlx::Row; use crate::error::{infrastructure, ControlCenterError, Result}; use crate::mfa::types::{BackupCode, TotpDevice, WebAuthnDevice}; use crate::storage::database::Database; /// MFA storage trait #[async_trait] pub trait MfaStorage: Send + Sync { // TOTP device operations async fn create_totp_device(&self, device: &TotpDevice) -> Result<()>; async fn get_totp_device(&self, device_id: &str) -> Result>; async fn get_totp_devices_by_user(&self, user_id: &str) -> Result>; async fn update_totp_device(&self, device: &TotpDevice) -> Result<()>; async fn delete_totp_device(&self, device_id: &str) -> Result<()>; // WebAuthn device operations async fn create_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()>; async fn get_webauthn_device(&self, device_id: &str) -> Result>; async fn get_webauthn_devices_by_user(&self, user_id: &str) -> Result>; async fn update_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()>; async fn delete_webauthn_device(&self, device_id: &str) -> Result<()>; // Generic operations async fn get_all_devices_by_user( &self, user_id: &str, ) -> Result<(Vec, Vec)>; async fn delete_all_devices_by_user(&self, user_id: &str) -> Result<()>; async fn user_has_mfa(&self, user_id: &str) -> Result; } /// SQLite-based MFA storage implementation pub struct SqliteMfaStorage { db: Database, } impl SqliteMfaStorage { /// Create new SQLite storage pub async fn new(db: Database) -> Result { let storage = Self { db }; storage.initialize().await?; Ok(storage) } /// Initialize database tables async fn initialize(&self) -> Result<()> { // Create TOTP devices table sqlx::query( r#" CREATE TABLE IF NOT EXISTS mfa_totp_devices ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, secret TEXT NOT NULL, algorithm TEXT NOT NULL, digits INTEGER NOT NULL, period INTEGER NOT NULL, created_at TEXT NOT NULL, last_used TEXT, enabled INTEGER NOT NULL, FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE ) "#, ) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to create totp_devices table: {}", e ))) })?; // Create backup codes table sqlx::query( r#" CREATE TABLE IF NOT EXISTS mfa_backup_codes ( id INTEGER PRIMARY KEY AUTOINCREMENT, device_id TEXT NOT NULL, code_hash TEXT NOT NULL, used INTEGER NOT NULL, used_at TEXT, FOREIGN KEY (device_id) REFERENCES mfa_totp_devices(id) ON DELETE CASCADE ) "#, ) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to create backup_codes table: {}", e ))) })?; // Create WebAuthn devices table sqlx::query( r#" CREATE TABLE IF NOT EXISTS mfa_webauthn_devices ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, credential_id BLOB NOT NULL, public_key BLOB NOT NULL, counter INTEGER NOT NULL, device_name TEXT NOT NULL, created_at TEXT NOT NULL, last_used TEXT, enabled INTEGER NOT NULL, attestation_type TEXT, transports TEXT, FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE ) "#, ) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to create webauthn_devices table: {}", e ))) })?; // Create indices sqlx::query("CREATE INDEX IF NOT EXISTS idx_totp_user_id ON mfa_totp_devices(user_id)") .execute(self.db.pool()) .await .ok(); sqlx::query( "CREATE INDEX IF NOT EXISTS idx_webauthn_user_id ON mfa_webauthn_devices(user_id)", ) .execute(self.db.pool()) .await .ok(); sqlx::query( "CREATE INDEX IF NOT EXISTS idx_backup_device_id ON mfa_backup_codes(device_id)", ) .execute(self.db.pool()) .await .ok(); Ok(()) } } #[async_trait] impl MfaStorage for SqliteMfaStorage { async fn create_totp_device(&self, device: &TotpDevice) -> Result<()> { let mut tx = self.db.pool().begin().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to start transaction: {}", e ))) })?; // Insert device sqlx::query( r#" INSERT INTO mfa_totp_devices (id, user_id, secret, algorithm, digits, period, created_at, last_used, enabled) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(&device.id) .bind(&device.user_id) .bind(&device.secret) .bind(format!("{:?}", device.algorithm)) .bind(device.digits as i64) .bind(device.period as i64) .bind(device.created_at.to_rfc3339()) .bind(device.last_used.map(|dt| dt.to_rfc3339())) .bind(device.enabled as i64) .execute(&mut *tx) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to insert TOTP device: {}", e ))) })?; // Insert backup codes for backup_code in &device.backup_codes { sqlx::query( r#" INSERT INTO mfa_backup_codes (device_id, code_hash, used, used_at) VALUES (?, ?, ?, ?) "#, ) .bind(&device.id) .bind(&backup_code.code_hash) .bind(backup_code.used as i64) .bind(backup_code.used_at.map(|dt| dt.to_rfc3339())) .execute(&mut *tx) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to insert backup code: {}", e ))) })?; } tx.commit().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to commit transaction: {}", e ))) })?; Ok(()) } async fn get_totp_device(&self, device_id: &str) -> Result> { let row = sqlx::query( r#" SELECT id, user_id, secret, algorithm, digits, period, created_at, last_used, enabled FROM mfa_totp_devices WHERE id = ? "#, ) .bind(device_id) .fetch_optional(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to fetch TOTP device: {}", e ))) })?; if let Some(row) = row { let id: String = row.get("id"); let user_id: String = row.get("user_id"); let secret: String = row.get("secret"); let algorithm: String = row.get("algorithm"); let digits: i64 = row.get("digits"); let period: i64 = row.get("period"); let created_at: String = row.get("created_at"); let last_used: Option = row.get("last_used"); let enabled: i64 = row.get("enabled"); // Fetch backup codes let backup_rows = sqlx::query( r#" SELECT code_hash, used, used_at FROM mfa_backup_codes WHERE device_id = ? "#, ) .bind(&id) .fetch_all(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to fetch backup codes: {}", e ))) })?; let backup_codes: Vec = backup_rows .iter() .map(|row| { let code_hash: String = row.get("code_hash"); let used: i64 = row.get("used"); let used_at: Option = row.get("used_at"); BackupCode { code: None, code_hash, used: used != 0, used_at: used_at.and_then(|s| { DateTime::parse_from_rfc3339(&s) .ok() .map(|dt| dt.with_timezone(&Utc)) }), } }) .collect(); let algorithm = match algorithm.as_str() { "Sha1" => crate::mfa::types::TotpAlgorithm::Sha1, "Sha256" => crate::mfa::types::TotpAlgorithm::Sha256, "Sha512" => crate::mfa::types::TotpAlgorithm::Sha512, _ => crate::mfa::types::TotpAlgorithm::Sha1, }; let device = TotpDevice { id, user_id, secret, algorithm, digits: digits as u32, period: period as u32, created_at: DateTime::parse_from_rfc3339(&created_at) .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Invalid datetime: {}", e ))) })? .with_timezone(&Utc), last_used: last_used .and_then(|s| DateTime::parse_from_rfc3339(&s).ok()) .map(|dt| dt.with_timezone(&Utc)), backup_codes, enabled: enabled != 0, }; Ok(Some(device)) } else { Ok(None) } } async fn get_totp_devices_by_user(&self, user_id: &str) -> Result> { let rows = sqlx::query( r#" SELECT id, user_id, secret, algorithm, digits, period, created_at, last_used, enabled FROM mfa_totp_devices WHERE user_id = ? ORDER BY created_at DESC "#, ) .bind(user_id) .fetch_all(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to fetch TOTP devices: {}", e ))) })?; let mut devices = Vec::new(); for row in rows { let device_id: String = row.get("id"); if let Some(device) = self.get_totp_device(&device_id).await? { devices.push(device); } } Ok(devices) } async fn update_totp_device(&self, device: &TotpDevice) -> Result<()> { let mut tx = self.db.pool().begin().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to start transaction: {}", e ))) })?; // Update device sqlx::query( r#" UPDATE mfa_totp_devices SET last_used = ?, enabled = ? WHERE id = ? "#, ) .bind(device.last_used.map(|dt| dt.to_rfc3339())) .bind(device.enabled as i64) .bind(&device.id) .execute(&mut *tx) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to update TOTP device: {}", e ))) })?; // Update backup codes sqlx::query("DELETE FROM mfa_backup_codes WHERE device_id = ?") .bind(&device.id) .execute(&mut *tx) .await .ok(); for backup_code in &device.backup_codes { sqlx::query( r#" INSERT INTO mfa_backup_codes (device_id, code_hash, used, used_at) VALUES (?, ?, ?, ?) "#, ) .bind(&device.id) .bind(&backup_code.code_hash) .bind(backup_code.used as i64) .bind(backup_code.used_at.map(|dt| dt.to_rfc3339())) .execute(&mut *tx) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to insert backup code: {}", e ))) })?; } tx.commit().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to commit transaction: {}", e ))) })?; Ok(()) } async fn delete_totp_device(&self, device_id: &str) -> Result<()> { sqlx::query("DELETE FROM mfa_totp_devices WHERE id = ?") .bind(device_id) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to delete TOTP device: {}", e ))) })?; Ok(()) } async fn create_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()> { let transports = serde_json::to_string(&device.transports).map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to serialize transports: {}", e ))) })?; sqlx::query( r#" INSERT INTO mfa_webauthn_devices (id, user_id, credential_id, public_key, counter, device_name, created_at, last_used, enabled, attestation_type, transports) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(&device.id) .bind(&device.user_id) .bind(&device.credential_id) .bind(&device.public_key) .bind(device.counter as i64) .bind(&device.device_name) .bind(device.created_at.to_rfc3339()) .bind(device.last_used.map(|dt| dt.to_rfc3339())) .bind(device.enabled as i64) .bind(&device.attestation_type) .bind(transports) .execute(self.db.pool()) .await .map_err(|e| ControlCenterError::from(infrastructure::internal_error(&format!("Failed to insert WebAuthn device: {}", e))))?; Ok(()) } async fn get_webauthn_device(&self, device_id: &str) -> Result> { let row = sqlx::query( r#" SELECT id, user_id, credential_id, public_key, counter, device_name, created_at, last_used, enabled, attestation_type, transports FROM mfa_webauthn_devices WHERE id = ? "#, ) .bind(device_id) .fetch_optional(self.db.pool()) .await .map_err(|e| ControlCenterError::from(infrastructure::internal_error(&format!("Failed to fetch WebAuthn device: {}", e))))?; if let Some(row) = row { let id: String = row.get("id"); let user_id: String = row.get("user_id"); let credential_id: Vec = row.get("credential_id"); let public_key: Vec = row.get("public_key"); let counter: i64 = row.get("counter"); let device_name: String = row.get("device_name"); let created_at: String = row.get("created_at"); let last_used: Option = row.get("last_used"); let enabled: i64 = row.get("enabled"); let attestation_type: Option = row.get("attestation_type"); let transports_json: String = row.get("transports"); let transports: Vec = serde_json::from_str(&transports_json).map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to deserialize transports: {}", e ))) })?; let device = WebAuthnDevice { id, user_id, credential_id, public_key, counter: counter as u32, device_name, created_at: DateTime::parse_from_rfc3339(&created_at) .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Invalid datetime: {}", e ))) })? .with_timezone(&Utc), last_used: last_used .and_then(|s| DateTime::parse_from_rfc3339(&s).ok()) .map(|dt| dt.with_timezone(&Utc)), enabled: enabled != 0, attestation_type, transports, }; Ok(Some(device)) } else { Ok(None) } } async fn get_webauthn_devices_by_user(&self, user_id: &str) -> Result> { let rows = sqlx::query( r#" SELECT id FROM mfa_webauthn_devices WHERE user_id = ? ORDER BY created_at DESC "#, ) .bind(user_id) .fetch_all(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to fetch WebAuthn devices: {}", e ))) })?; let mut devices = Vec::new(); for row in rows { let device_id: String = row.get("id"); if let Some(device) = self.get_webauthn_device(&device_id).await? { devices.push(device); } } Ok(devices) } async fn update_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()> { let transports = serde_json::to_string(&device.transports).map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to serialize transports: {}", e ))) })?; sqlx::query( r#" UPDATE mfa_webauthn_devices SET counter = ?, last_used = ?, enabled = ?, transports = ? WHERE id = ? "#, ) .bind(device.counter as i64) .bind(device.last_used.map(|dt| dt.to_rfc3339())) .bind(device.enabled as i64) .bind(transports) .bind(&device.id) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to update WebAuthn device: {}", e ))) })?; Ok(()) } async fn delete_webauthn_device(&self, device_id: &str) -> Result<()> { sqlx::query("DELETE FROM mfa_webauthn_devices WHERE id = ?") .bind(device_id) .execute(self.db.pool()) .await .map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to delete WebAuthn device: {}", e ))) })?; Ok(()) } async fn get_all_devices_by_user( &self, user_id: &str, ) -> Result<(Vec, Vec)> { let totp_devices = self.get_totp_devices_by_user(user_id).await?; let webauthn_devices = self.get_webauthn_devices_by_user(user_id).await?; Ok((totp_devices, webauthn_devices)) } async fn delete_all_devices_by_user(&self, user_id: &str) -> Result<()> { let mut tx = self.db.pool().begin().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to start transaction: {}", e ))) })?; sqlx::query("DELETE FROM mfa_totp_devices WHERE user_id = ?") .bind(user_id) .execute(&mut *tx) .await .ok(); sqlx::query("DELETE FROM mfa_webauthn_devices WHERE user_id = ?") .bind(user_id) .execute(&mut *tx) .await .ok(); tx.commit().await.map_err(|e| { ControlCenterError::from(infrastructure::internal_error(&format!( "Failed to commit transaction: {}", e ))) })?; Ok(()) } async fn user_has_mfa(&self, user_id: &str) -> Result { let totp_count: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM mfa_totp_devices WHERE user_id = ? AND enabled = 1", ) .bind(user_id) .fetch_one(self.db.pool()) .await .unwrap_or(0); let webauthn_count: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM mfa_webauthn_devices WHERE user_id = ? AND enabled = 1", ) .bind(user_id) .fetch_one(self.db.pool()) .await .unwrap_or(0); Ok(totp_count > 0 || webauthn_count > 0) } } #[cfg(test)] mod tests { use super::*; use crate::storage::database::DatabaseConfig; async fn create_test_storage() -> SqliteMfaStorage { let config = DatabaseConfig { url: ":memory:".to_string(), max_connections: 5, }; let db = Database::new(config).await.unwrap(); // Create users table for foreign key sqlx::query( r#" CREATE TABLE users ( id TEXT PRIMARY KEY, email TEXT NOT NULL UNIQUE, name TEXT NOT NULL ) "#, ) .execute(db.pool()) .await .unwrap(); SqliteMfaStorage::new(db).await.unwrap() } #[tokio::test] async fn test_totp_device_crud() { let storage = create_test_storage().await; // Create test user let user_id = "test_user"; sqlx::query("INSERT INTO users (id, email, name) VALUES (?, ?, ?)") .bind(user_id) .bind("test@example.com") .bind("Test User") .execute(storage.db.pool()) .await .unwrap(); // Create device let device = TotpDevice::new(user_id.to_string(), "SECRET123".to_string()); let device_id = device.id.clone(); storage.create_totp_device(&device).await.unwrap(); // Read device let loaded = storage.get_totp_device(&device_id).await.unwrap(); assert!(loaded.is_some()); let loaded = loaded.unwrap(); assert_eq!(loaded.user_id, user_id); assert_eq!(loaded.secret, "SECRET123"); // Update device let mut updated = loaded.clone(); updated.enabled = true; storage.update_totp_device(&updated).await.unwrap(); let reloaded = storage.get_totp_device(&device_id).await.unwrap().unwrap(); assert!(reloaded.enabled); // Delete device storage.delete_totp_device(&device_id).await.unwrap(); let deleted = storage.get_totp_device(&device_id).await.unwrap(); assert!(deleted.is_none()); } #[tokio::test] async fn test_webauthn_device_crud() { let storage = create_test_storage().await; // Create test user let user_id = "test_user"; sqlx::query("INSERT INTO users (id, email, name) VALUES (?, ?, ?)") .bind(user_id) .bind("test@example.com") .bind("Test User") .execute(storage.db.pool()) .await .unwrap(); // Create device let device = WebAuthnDevice::new( user_id.to_string(), vec![1, 2, 3], vec![4, 5, 6], "YubiKey".to_string(), ); let device_id = device.id.clone(); storage.create_webauthn_device(&device).await.unwrap(); // Read device let loaded = storage.get_webauthn_device(&device_id).await.unwrap(); assert!(loaded.is_some()); let loaded = loaded.unwrap(); assert_eq!(loaded.user_id, user_id); assert_eq!(loaded.device_name, "YubiKey"); // Delete device storage.delete_webauthn_device(&device_id).await.unwrap(); let deleted = storage.get_webauthn_device(&device_id).await.unwrap(); assert!(deleted.is_none()); } #[tokio::test] async fn test_user_has_mfa() { let storage = create_test_storage().await; // Create test user let user_id = "test_user"; sqlx::query("INSERT INTO users (id, email, name) VALUES (?, ?, ?)") .bind(user_id) .bind("test@example.com") .bind("Test User") .execute(storage.db.pool()) .await .unwrap(); // Initially no MFA assert!(!storage.user_has_mfa(user_id).await.unwrap()); // Add TOTP device let mut device = TotpDevice::new(user_id.to_string(), "SECRET123".to_string()); device.enabled = true; storage.create_totp_device(&device).await.unwrap(); // Now has MFA assert!(storage.user_has_mfa(user_id).await.unwrap()); } }