799 lines
26 KiB
Rust
Raw Normal View History

//! MFA device storage and persistence
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::Row;
use crate::error::{infrastructure, ControlCenterError, Result};
use crate::mfa::types::{BackupCode, TotpDevice, WebAuthnDevice};
use crate::storage::database::Database;
/// MFA storage trait
///
/// Async persistence interface for TOTP and WebAuthn devices. Implementors
/// must be `Send + Sync`. Every method reports failure through the crate-wide
/// `Result` type.
#[async_trait]
pub trait MfaStorage: Send + Sync {
// TOTP device operations
/// Stores a new TOTP device.
async fn create_totp_device(&self, device: &TotpDevice) -> Result<()>;
/// Looks up a TOTP device by its id; the `Option` allows for "no such device".
async fn get_totp_device(&self, device_id: &str) -> Result<Option<TotpDevice>>;
/// Returns the TOTP devices belonging to `user_id`.
async fn get_totp_devices_by_user(&self, user_id: &str) -> Result<Vec<TotpDevice>>;
/// Writes back changes to an existing TOTP device.
async fn update_totp_device(&self, device: &TotpDevice) -> Result<()>;
/// Removes the TOTP device with the given id.
async fn delete_totp_device(&self, device_id: &str) -> Result<()>;
// WebAuthn device operations
/// Stores a new WebAuthn device.
async fn create_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()>;
/// Looks up a WebAuthn device by its id; the `Option` allows for "no such device".
async fn get_webauthn_device(&self, device_id: &str) -> Result<Option<WebAuthnDevice>>;
/// Returns the WebAuthn devices belonging to `user_id`.
async fn get_webauthn_devices_by_user(&self, user_id: &str) -> Result<Vec<WebAuthnDevice>>;
/// Writes back changes to an existing WebAuthn device.
async fn update_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()>;
/// Removes the WebAuthn device with the given id.
async fn delete_webauthn_device(&self, device_id: &str) -> Result<()>;
// Generic operations
/// Returns both device kinds for `user_id` as `(totp, webauthn)`.
async fn get_all_devices_by_user(
&self,
user_id: &str,
) -> Result<(Vec<TotpDevice>, Vec<WebAuthnDevice>)>;
/// Removes every MFA device (TOTP and WebAuthn) belonging to `user_id`.
async fn delete_all_devices_by_user(&self, user_id: &str) -> Result<()>;
/// Reports whether `user_id` has MFA configured.
// NOTE(review): the SQLite implementation counts only devices with `enabled = 1`;
// confirm that other implementors are expected to apply the same rule.
async fn user_has_mfa(&self, user_id: &str) -> Result<bool>;
}
/// SQLite-based MFA storage implementation
///
/// Thin wrapper around a [`Database`]; every query in this type runs against
/// the pool returned by `db.pool()`.
pub struct SqliteMfaStorage {
// Database handle that supplies the connection pool for all queries.
db: Database,
}
impl SqliteMfaStorage {
    /// Wraps `db` and ensures the MFA schema exists before returning the storage.
    ///
    /// # Errors
    /// Propagates any error from schema initialization (a failed `CREATE TABLE`).
    pub async fn new(db: Database) -> Result<Self> {
        let this = Self { db };
        this.initialize().await?;
        Ok(this)
    }

    /// Creates the three MFA tables (idempotently) and their lookup indices.
    ///
    /// A failed table creation is returned as an internal error. Index creation
    /// is best-effort: its result is discarded.
    async fn initialize(&self) -> Result<()> {
        let pool = self.db.pool();
        let internal =
            |msg: String| ControlCenterError::from(infrastructure::internal_error(&msg));

        // TOTP devices. Created first because the backup-code table references it.
        let totp_ddl = r#"
CREATE TABLE IF NOT EXISTS mfa_totp_devices (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
secret TEXT NOT NULL,
algorithm TEXT NOT NULL,
digits INTEGER NOT NULL,
period INTEGER NOT NULL,
created_at TEXT NOT NULL,
last_used TEXT,
enabled INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
)
"#;
        sqlx::query(totp_ddl)
            .execute(pool)
            .await
            .map_err(|e| internal(format!("Failed to create totp_devices table: {}", e)))?;

        // Backup codes, owned by a TOTP device and removed with it (ON DELETE CASCADE).
        let backup_ddl = r#"
CREATE TABLE IF NOT EXISTS mfa_backup_codes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_id TEXT NOT NULL,
code_hash TEXT NOT NULL,
used INTEGER NOT NULL,
used_at TEXT,
FOREIGN KEY (device_id) REFERENCES mfa_totp_devices(id) ON DELETE CASCADE
)
"#;
        sqlx::query(backup_ddl)
            .execute(pool)
            .await
            .map_err(|e| internal(format!("Failed to create backup_codes table: {}", e)))?;

        // WebAuthn devices.
        let webauthn_ddl = r#"
CREATE TABLE IF NOT EXISTS mfa_webauthn_devices (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
credential_id BLOB NOT NULL,
public_key BLOB NOT NULL,
counter INTEGER NOT NULL,
device_name TEXT NOT NULL,
created_at TEXT NOT NULL,
last_used TEXT,
enabled INTEGER NOT NULL,
attestation_type TEXT,
transports TEXT,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
)
"#;
        sqlx::query(webauthn_ddl)
            .execute(pool)
            .await
            .map_err(|e| internal(format!("Failed to create webauthn_devices table: {}", e)))?;

        // Lookup indices: failures are deliberately ignored (best-effort).
        let index_statements = [
            "CREATE INDEX IF NOT EXISTS idx_totp_user_id ON mfa_totp_devices(user_id)",
            "CREATE INDEX IF NOT EXISTS idx_webauthn_user_id ON mfa_webauthn_devices(user_id)",
            "CREATE INDEX IF NOT EXISTS idx_backup_device_id ON mfa_backup_codes(device_id)",
        ];
        for statement in index_statements {
            sqlx::query(statement).execute(pool).await.ok();
        }

        Ok(())
    }
}
#[async_trait]
impl MfaStorage for SqliteMfaStorage {
/// Inserts a TOTP device together with its backup codes in one transaction.
///
/// The algorithm is stored as its `Debug` rendering and timestamps as RFC 3339
/// text. Any failure is returned as an internal error before the commit.
async fn create_totp_device(&self, device: &TotpDevice) -> Result<()> {
    let internal = |msg: String| ControlCenterError::from(infrastructure::internal_error(&msg));

    let mut txn = self
        .db
        .pool()
        .begin()
        .await
        .map_err(|e| internal(format!("Failed to start transaction: {}", e)))?;

    // Pre-render the textual columns so the bind chain below stays readable.
    let algorithm_text = format!("{:?}", device.algorithm);
    let created_text = device.created_at.to_rfc3339();
    let last_used_text = device.last_used.map(|dt| dt.to_rfc3339());

    // Device row
    sqlx::query(
        r#"
INSERT INTO mfa_totp_devices
(id, user_id, secret, algorithm, digits, period, created_at, last_used, enabled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
    )
    .bind(&device.id)
    .bind(&device.user_id)
    .bind(&device.secret)
    .bind(algorithm_text)
    .bind(device.digits as i64)
    .bind(device.period as i64)
    .bind(created_text)
    .bind(last_used_text)
    .bind(device.enabled as i64)
    .execute(&mut *txn)
    .await
    .map_err(|e| internal(format!("Failed to insert TOTP device: {}", e)))?;

    // One row per backup code, all tied to the device id.
    for code in device.backup_codes.iter() {
        let used_at_text = code.used_at.map(|dt| dt.to_rfc3339());
        sqlx::query(
            r#"
INSERT INTO mfa_backup_codes (device_id, code_hash, used, used_at)
VALUES (?, ?, ?, ?)
"#,
        )
        .bind(&device.id)
        .bind(&code.code_hash)
        .bind(code.used as i64)
        .bind(used_at_text)
        .execute(&mut *txn)
        .await
        .map_err(|e| internal(format!("Failed to insert backup code: {}", e)))?;
    }

    txn.commit()
        .await
        .map_err(|e| internal(format!("Failed to commit transaction: {}", e)))?;
    Ok(())
}
/// Loads a TOTP device and its backup codes by device id.
///
/// Returns `Ok(None)` when no row matches `device_id`. A `created_at` value
/// that is not valid RFC 3339 is an error; unparsable `last_used` / `used_at`
/// values are silently read as `None`.
async fn get_totp_device(&self, device_id: &str) -> Result<Option<TotpDevice>> {
    let internal = |msg: String| ControlCenterError::from(infrastructure::internal_error(&msg));
    let pool = self.db.pool();

    let fetched = sqlx::query(
        r#"
SELECT id, user_id, secret, algorithm, digits, period, created_at, last_used, enabled
FROM mfa_totp_devices
WHERE id = ?
"#,
    )
    .bind(device_id)
    .fetch_optional(pool)
    .await
    .map_err(|e| internal(format!("Failed to fetch TOTP device: {}", e)))?;

    let device_row = match fetched {
        Some(found) => found,
        None => return Ok(None),
    };

    let id: String = device_row.get("id");
    let user_id: String = device_row.get("user_id");
    let secret: String = device_row.get("secret");
    let algorithm_text: String = device_row.get("algorithm");
    let digits: i64 = device_row.get("digits");
    let period: i64 = device_row.get("period");
    let created_text: String = device_row.get("created_at");
    let last_used_text: Option<String> = device_row.get("last_used");
    let enabled_flag: i64 = device_row.get("enabled");

    // Backup codes live in their own table, keyed by the device id.
    let code_rows = sqlx::query(
        r#"
SELECT code_hash, used, used_at
FROM mfa_backup_codes
WHERE device_id = ?
"#,
    )
    .bind(&id)
    .fetch_all(pool)
    .await
    .map_err(|e| internal(format!("Failed to fetch backup codes: {}", e)))?;

    let mut backup_codes: Vec<BackupCode> = Vec::with_capacity(code_rows.len());
    for code_row in &code_rows {
        let code_hash: String = code_row.get("code_hash");
        let used_flag: i64 = code_row.get("used");
        let used_at_text: Option<String> = code_row.get("used_at");
        backup_codes.push(BackupCode {
            // Only the hash is persisted; the plaintext code is never read back.
            code: None,
            code_hash,
            used: used_flag != 0,
            used_at: used_at_text
                .and_then(|raw| DateTime::parse_from_rfc3339(&raw).ok())
                .map(|dt| dt.with_timezone(&Utc)),
        });
    }

    // NOTE(review): an unrecognised algorithm name silently falls back to SHA-1.
    let algorithm = match algorithm_text.as_str() {
        "Sha1" => crate::mfa::types::TotpAlgorithm::Sha1,
        "Sha256" => crate::mfa::types::TotpAlgorithm::Sha256,
        "Sha512" => crate::mfa::types::TotpAlgorithm::Sha512,
        _ => crate::mfa::types::TotpAlgorithm::Sha1,
    };

    let created_at = DateTime::parse_from_rfc3339(&created_text)
        .map_err(|e| internal(format!("Invalid datetime: {}", e)))?
        .with_timezone(&Utc);
    let last_used = last_used_text
        .and_then(|raw| DateTime::parse_from_rfc3339(&raw).ok())
        .map(|dt| dt.with_timezone(&Utc));

    Ok(Some(TotpDevice {
        id,
        user_id,
        secret,
        algorithm,
        // NOTE(review): `as` truncates silently if the stored value is out of range.
        digits: digits as u32,
        period: period as u32,
        created_at,
        last_used,
        backup_codes,
        enabled: enabled_flag != 0,
    }))
}
/// Returns all TOTP devices of `user_id`, newest first (`created_at DESC`).
///
/// Only the device ids are selected here; each device, with its backup codes,
/// is then loaded through [`Self::get_totp_device`]. This matches the shape of
/// `get_webauthn_devices_by_user` and avoids fetching columns that were never read.
///
/// # Errors
/// Returns an internal error if the id query fails, and propagates any error
/// from loading an individual device.
async fn get_totp_devices_by_user(&self, user_id: &str) -> Result<Vec<TotpDevice>> {
    let rows = sqlx::query(
        r#"
SELECT id
FROM mfa_totp_devices
WHERE user_id = ?
ORDER BY created_at DESC
"#,
    )
    .bind(user_id)
    .fetch_all(self.db.pool())
    .await
    .map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to fetch TOTP devices: {}",
            e
        )))
    })?;

    let mut devices = Vec::with_capacity(rows.len());
    for row in rows {
        let device_id: String = row.get("id");
        // A device deleted between the two queries simply drops out of the result.
        if let Some(device) = self.get_totp_device(&device_id).await? {
            devices.push(device);
        }
    }
    Ok(devices)
}
/// Updates a TOTP device's `last_used` / `enabled` fields and replaces its
/// backup codes, all inside one transaction.
///
/// Backup codes are replaced wholesale: existing rows for the device are
/// deleted and the codes in `device.backup_codes` are re-inserted.
///
/// # Errors
/// Returns an internal error if any statement or the commit fails. Returning
/// early drops the uncommitted transaction, which sqlx rolls back, so a failed
/// update leaves the stored device untouched.
async fn update_totp_device(&self, device: &TotpDevice) -> Result<()> {
    let mut tx = self.db.pool().begin().await.map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to start transaction: {}",
            e
        )))
    })?;

    // Update device
    sqlx::query(
        r#"
UPDATE mfa_totp_devices
SET last_used = ?, enabled = ?
WHERE id = ?
"#,
    )
    .bind(device.last_used.map(|dt| dt.to_rfc3339()))
    .bind(device.enabled as i64)
    .bind(&device.id)
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to update TOTP device: {}",
            e
        )))
    })?;

    // Replace backup codes. The delete must succeed before re-inserting:
    // ignoring a failure here would commit the old codes alongside the new
    // ones, duplicating rows and resurrecting codes already marked as used.
    sqlx::query("DELETE FROM mfa_backup_codes WHERE device_id = ?")
        .bind(&device.id)
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            ControlCenterError::from(infrastructure::internal_error(&format!(
                "Failed to delete backup codes: {}",
                e
            )))
        })?;

    for backup_code in &device.backup_codes {
        sqlx::query(
            r#"
INSERT INTO mfa_backup_codes (device_id, code_hash, used, used_at)
VALUES (?, ?, ?, ?)
"#,
        )
        .bind(&device.id)
        .bind(&backup_code.code_hash)
        .bind(backup_code.used as i64)
        .bind(backup_code.used_at.map(|dt| dt.to_rfc3339()))
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            ControlCenterError::from(infrastructure::internal_error(&format!(
                "Failed to insert backup code: {}",
                e
            )))
        })?;
    }

    tx.commit().await.map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to commit transaction: {}",
            e
        )))
    })?;
    Ok(())
}
/// Deletes the TOTP device row with the given id.
///
/// Succeeds even when no row matches; only a failing statement is an error.
async fn delete_totp_device(&self, device_id: &str) -> Result<()> {
    let outcome = sqlx::query("DELETE FROM mfa_totp_devices WHERE id = ?")
        .bind(device_id)
        .execute(self.db.pool())
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(ControlCenterError::from(infrastructure::internal_error(
            &format!("Failed to delete TOTP device: {}", e),
        ))),
    }
}
/// Inserts a WebAuthn device row.
///
/// `transports` is stored as a JSON string and timestamps as RFC 3339 text.
/// Serialization and insert failures are returned as internal errors.
async fn create_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()> {
    let internal = |msg: String| ControlCenterError::from(infrastructure::internal_error(&msg));

    let transports_json = serde_json::to_string(&device.transports)
        .map_err(|e| internal(format!("Failed to serialize transports: {}", e)))?;
    let created_text = device.created_at.to_rfc3339();
    let last_used_text = device.last_used.map(|dt| dt.to_rfc3339());

    sqlx::query(
        r#"
INSERT INTO mfa_webauthn_devices
(id, user_id, credential_id, public_key, counter, device_name, created_at, last_used, enabled, attestation_type, transports)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
    )
    .bind(&device.id)
    .bind(&device.user_id)
    .bind(&device.credential_id)
    .bind(&device.public_key)
    .bind(device.counter as i64)
    .bind(&device.device_name)
    .bind(created_text)
    .bind(last_used_text)
    .bind(device.enabled as i64)
    .bind(&device.attestation_type)
    .bind(transports_json)
    .execute(self.db.pool())
    .await
    .map_err(|e| internal(format!("Failed to insert WebAuthn device: {}", e)))?;

    Ok(())
}
/// Loads a WebAuthn device by its id.
///
/// Returns `Ok(None)` when no row matches `device_id`.
///
/// The `transports` column is nullable in the schema, so it is decoded as an
/// `Option`. A NULL value yields an empty transport list; decoding it as a
/// plain `String` would panic inside `Row::get`.
///
/// # Errors
/// Returns an internal error if the query fails, if a non-NULL `transports`
/// value is not valid JSON, or if `created_at` is not valid RFC 3339. An
/// unparsable `last_used` is read as `None`.
async fn get_webauthn_device(&self, device_id: &str) -> Result<Option<WebAuthnDevice>> {
    let row = sqlx::query(
        r#"
SELECT id, user_id, credential_id, public_key, counter, device_name, created_at, last_used, enabled, attestation_type, transports
FROM mfa_webauthn_devices
WHERE id = ?
"#,
    )
    .bind(device_id)
    .fetch_optional(self.db.pool())
    .await
    .map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to fetch WebAuthn device: {}",
            e
        )))
    })?;

    let row = match row {
        Some(row) => row,
        None => return Ok(None),
    };

    let id: String = row.get("id");
    let user_id: String = row.get("user_id");
    let credential_id: Vec<u8> = row.get("credential_id");
    let public_key: Vec<u8> = row.get("public_key");
    let counter: i64 = row.get("counter");
    let device_name: String = row.get("device_name");
    let created_at: String = row.get("created_at");
    let last_used: Option<String> = row.get("last_used");
    let enabled: i64 = row.get("enabled");
    let attestation_type: Option<String> = row.get("attestation_type");
    // Nullable column: decode as Option so a NULL does not panic.
    let transports_json: Option<String> = row.get("transports");

    let transports: Vec<String> = match transports_json {
        Some(json) => serde_json::from_str(&json).map_err(|e| {
            ControlCenterError::from(infrastructure::internal_error(&format!(
                "Failed to deserialize transports: {}",
                e
            )))
        })?,
        None => Vec::new(),
    };

    let device = WebAuthnDevice {
        id,
        user_id,
        credential_id,
        public_key,
        // NOTE(review): `as` truncates silently if the stored counter exceeds u32.
        counter: counter as u32,
        device_name,
        created_at: DateTime::parse_from_rfc3339(&created_at)
            .map_err(|e| {
                ControlCenterError::from(infrastructure::internal_error(&format!(
                    "Invalid datetime: {}",
                    e
                )))
            })?
            .with_timezone(&Utc),
        last_used: last_used
            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
            .map(|dt| dt.with_timezone(&Utc)),
        enabled: enabled != 0,
        attestation_type,
        transports,
    };
    Ok(Some(device))
}
/// Returns all WebAuthn devices of `user_id`, newest first (`created_at DESC`).
///
/// Selects the matching ids, then loads each device through
/// `get_webauthn_device`; ids that no longer resolve are skipped.
async fn get_webauthn_devices_by_user(&self, user_id: &str) -> Result<Vec<WebAuthnDevice>> {
    let id_rows = sqlx::query(
        r#"
SELECT id
FROM mfa_webauthn_devices
WHERE user_id = ?
ORDER BY created_at DESC
"#,
    )
    .bind(user_id)
    .fetch_all(self.db.pool())
    .await
    .map_err(|e| {
        let message = format!("Failed to fetch WebAuthn devices: {}", e);
        ControlCenterError::from(infrastructure::internal_error(&message))
    })?;

    let mut found = Vec::new();
    for id_row in id_rows {
        let device_id: String = id_row.get("id");
        match self.get_webauthn_device(&device_id).await? {
            Some(device) => found.push(device),
            None => {}
        }
    }
    Ok(found)
}
/// Writes back the mutable fields of a WebAuthn device: `counter`,
/// `last_used`, `enabled` and `transports` (re-serialized as JSON).
///
/// Serialization and statement failures are returned as internal errors.
async fn update_webauthn_device(&self, device: &WebAuthnDevice) -> Result<()> {
    let internal = |msg: String| ControlCenterError::from(infrastructure::internal_error(&msg));

    let transports_json = serde_json::to_string(&device.transports)
        .map_err(|e| internal(format!("Failed to serialize transports: {}", e)))?;
    let last_used_text = device.last_used.map(|dt| dt.to_rfc3339());

    sqlx::query(
        r#"
UPDATE mfa_webauthn_devices
SET counter = ?, last_used = ?, enabled = ?, transports = ?
WHERE id = ?
"#,
    )
    .bind(device.counter as i64)
    .bind(last_used_text)
    .bind(device.enabled as i64)
    .bind(transports_json)
    .bind(&device.id)
    .execute(self.db.pool())
    .await
    .map_err(|e| internal(format!("Failed to update WebAuthn device: {}", e)))?;

    Ok(())
}
/// Deletes the WebAuthn device row with the given id.
///
/// Succeeds even when no row matches; only a failing statement is an error.
async fn delete_webauthn_device(&self, device_id: &str) -> Result<()> {
    let outcome = sqlx::query("DELETE FROM mfa_webauthn_devices WHERE id = ?")
        .bind(device_id)
        .execute(self.db.pool())
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(ControlCenterError::from(infrastructure::internal_error(
            &format!("Failed to delete WebAuthn device: {}", e),
        ))),
    }
}
/// Returns `(totp_devices, webauthn_devices)` for `user_id`.
///
/// The TOTP lookup runs first; the first error from either lookup is returned.
async fn get_all_devices_by_user(
    &self,
    user_id: &str,
) -> Result<(Vec<TotpDevice>, Vec<WebAuthnDevice>)> {
    // Tuple fields are evaluated left to right, so TOTP is still queried first.
    Ok((
        self.get_totp_devices_by_user(user_id).await?,
        self.get_webauthn_devices_by_user(user_id).await?,
    ))
}
/// Deletes every TOTP and WebAuthn device of `user_id` in one transaction.
///
/// # Errors
/// Returns an internal error if the transaction cannot be started, if either
/// delete fails, or if the commit fails. Previously the delete results were
/// discarded, so the call could report success while devices remained
/// registered. Returning early drops the uncommitted transaction, which sqlx
/// rolls back, so the removal is all-or-nothing.
async fn delete_all_devices_by_user(&self, user_id: &str) -> Result<()> {
    let mut tx = self.db.pool().begin().await.map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to start transaction: {}",
            e
        )))
    })?;

    sqlx::query("DELETE FROM mfa_totp_devices WHERE user_id = ?")
        .bind(user_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            ControlCenterError::from(infrastructure::internal_error(&format!(
                "Failed to delete TOTP devices: {}",
                e
            )))
        })?;

    sqlx::query("DELETE FROM mfa_webauthn_devices WHERE user_id = ?")
        .bind(user_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            ControlCenterError::from(infrastructure::internal_error(&format!(
                "Failed to delete WebAuthn devices: {}",
                e
            )))
        })?;

    tx.commit().await.map_err(|e| {
        ControlCenterError::from(infrastructure::internal_error(&format!(
            "Failed to commit transaction: {}",
            e
        )))
    })?;
    Ok(())
}
async fn user_has_mfa(&self, user_id: &str) -> Result<bool> {
let totp_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM mfa_totp_devices WHERE user_id = ? AND enabled = 1",
)
.bind(user_id)
.fetch_one(self.db.pool())
.await
.unwrap_or(0);
let webauthn_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM mfa_webauthn_devices WHERE user_id = ? AND enabled = 1",
)
.bind(user_id)
.fetch_one(self.db.pool())
.await
.unwrap_or(0);
Ok(totp_count > 0 || webauthn_count > 0)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::database::DatabaseConfig;

    /// Builds storage on an in-memory database that already has the `users`
    /// table the MFA tables reference through their foreign keys.
    async fn create_test_storage() -> SqliteMfaStorage {
        let db = Database::new(DatabaseConfig {
            url: ":memory:".to_string(),
            max_connections: 5,
        })
        .await
        .unwrap();

        // Create users table for foreign key
        sqlx::query(
            r#"
CREATE TABLE users (
id TEXT PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
name TEXT NOT NULL
)
"#,
        )
        .execute(db.pool())
        .await
        .unwrap();

        SqliteMfaStorage::new(db).await.unwrap()
    }

    /// Inserts the fixture user that every test attaches its devices to.
    async fn insert_test_user(storage: &SqliteMfaStorage, user_id: &str) {
        sqlx::query("INSERT INTO users (id, email, name) VALUES (?, ?, ?)")
            .bind(user_id)
            .bind("test@example.com")
            .bind("Test User")
            .execute(storage.db.pool())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_totp_device_crud() {
        let storage = create_test_storage().await;
        let user_id = "test_user";
        insert_test_user(&storage, user_id).await;

        // Create
        let device = TotpDevice::new(user_id.to_string(), "SECRET123".to_string());
        let device_id = device.id.clone();
        storage.create_totp_device(&device).await.unwrap();

        // Read
        let fetched = storage.get_totp_device(&device_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.secret, "SECRET123");

        // Update
        let mut changed = fetched.clone();
        changed.enabled = true;
        storage.update_totp_device(&changed).await.unwrap();
        let after_update = storage.get_totp_device(&device_id).await.unwrap().unwrap();
        assert!(after_update.enabled);

        // Delete
        storage.delete_totp_device(&device_id).await.unwrap();
        let after_delete = storage.get_totp_device(&device_id).await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_webauthn_device_crud() {
        let storage = create_test_storage().await;
        let user_id = "test_user";
        insert_test_user(&storage, user_id).await;

        // Create
        let device = WebAuthnDevice::new(
            user_id.to_string(),
            vec![1, 2, 3],
            vec![4, 5, 6],
            "YubiKey".to_string(),
        );
        let device_id = device.id.clone();
        storage.create_webauthn_device(&device).await.unwrap();

        // Read
        let fetched = storage.get_webauthn_device(&device_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.device_name, "YubiKey");

        // Delete
        storage.delete_webauthn_device(&device_id).await.unwrap();
        let after_delete = storage.get_webauthn_device(&device_id).await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_user_has_mfa() {
        let storage = create_test_storage().await;
        let user_id = "test_user";
        insert_test_user(&storage, user_id).await;

        // No devices yet, so no MFA.
        assert!(!storage.user_has_mfa(user_id).await.unwrap());

        // An enabled TOTP device flips the answer.
        let mut device = TotpDevice::new(user_id.to_string(), "SECRET123".to_string());
        device.enabled = true;
        storage.create_totp_device(&device).await.unwrap();

        assert!(storage.user_has_mfa(user_id).await.unwrap());
    }
}