2025-07-07 23:05:19 +01:00

362 lines
11 KiB
Rust

//! Database abstraction layer for supporting multiple database backends
//!
//! This module provides a unified interface for database operations that works
//! with both SQLite and PostgreSQL, allowing the application to be database-agnostic.
use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::{PgPool, Row, SqlitePool};
use std::time::Duration;
use uuid::Uuid;
pub mod auth;
pub mod connection;
pub mod migrations;
pub mod rbac;
/// Database configuration
///
/// `Debug` is implemented by hand rather than derived so that a password
/// embedded in `url` (e.g. `postgres://user:password@host/db`) is masked as
/// `***` whenever the configuration is logged or printed.
#[derive(Clone)]
pub struct DatabaseConfig {
    /// Connection URL; its scheme selects the backend (`postgres://`,
    /// `postgresql://` or `sqlite:`).
    pub url: String,
    /// Maximum number of connections the pool may open.
    pub max_connections: u32,
    /// Minimum number of connections the pool keeps open.
    pub min_connections: u32,
    /// How long to wait when acquiring a connection (passed to the pool's
    /// `acquire_timeout`).
    pub connect_timeout: Duration,
    /// Idle time after which a pooled connection may be closed.
    pub idle_timeout: Duration,
    /// Maximum lifetime of a pooled connection.
    pub max_lifetime: Duration,
}

impl std::fmt::Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Replace the password of a `scheme://user:password@host...` URL
        /// with `***`; any other URL is returned unchanged.
        fn redact_password(url: &str) -> String {
            // Only URLs with an authority component can carry userinfo.
            let authority_start = match url.find("://") {
                Some(idx) => idx + 3,
                None => return url.to_owned(),
            };
            let rest = &url[authority_start..];
            let authority_len = rest
                .find(|c: char| matches!(c, '/' | '?' | '#'))
                .unwrap_or(rest.len());
            let authority = &rest[..authority_len];
            // Userinfo ends at the last '@' of the authority.
            let at = match authority.rfind('@') {
                Some(idx) => idx,
                None => return url.to_owned(),
            };
            let userinfo = &authority[..at];
            match userinfo.find(':') {
                Some(colon) => format!(
                    "{}{}:***{}",
                    &url[..authority_start],
                    &userinfo[..colon],
                    &url[authority_start + at..]
                ),
                None => url.to_owned(),
            }
        }

        f.debug_struct("DatabaseConfig")
            .field("url", &redact_password(&self.url))
            .field("max_connections", &self.max_connections)
            .field("min_connections", &self.min_connections)
            .field("connect_timeout", &self.connect_timeout)
            .field("idle_timeout", &self.idle_timeout)
            .field("max_lifetime", &self.max_lifetime)
            .finish()
    }
}
/// Database backend kind, as inferred from a connection URL.
///
/// This is a fieldless enum, so it is `Copy` and `Hash`: it can be passed by
/// value and used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// PostgreSQL (`postgres://` or `postgresql://` URLs).
    PostgreSQL,
    /// SQLite (`sqlite:` URLs).
    SQLite,
}
/// Database connection pool abstraction
///
/// Holds the backend-specific sqlx pool selected by the configuration. Use
/// [`DatabasePool::database_type`] to find out which variant is held, or
/// `as_postgres` / `as_sqlite` to borrow the underlying pool directly.
#[derive(Debug, Clone)]
pub enum DatabasePool {
    /// Pool of PostgreSQL connections.
    PostgreSQL(PgPool),
    /// Pool of SQLite connections.
    SQLite(SqlitePool),
}
impl DatabasePool {
/// Create a new database pool from configuration
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
let db_type = Self::detect_type(&config.url)?;
match db_type {
DatabaseType::PostgreSQL => {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(config.max_connections)
.min_connections(config.min_connections)
.acquire_timeout(config.connect_timeout)
.idle_timeout(config.idle_timeout)
.max_lifetime(config.max_lifetime)
.connect(&config.url)
.await?;
Ok(DatabasePool::PostgreSQL(pool))
}
DatabaseType::SQLite => {
// Ensure directory exists for SQLite
if let Some(path) = config.url.strip_prefix("sqlite:") {
if let Some(parent) = std::path::Path::new(path).parent() {
tokio::fs::create_dir_all(parent).await?;
}
}
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(config.max_connections)
.min_connections(config.min_connections)
.acquire_timeout(config.connect_timeout)
.idle_timeout(config.idle_timeout)
.max_lifetime(config.max_lifetime)
.connect(&config.url)
.await?;
Ok(DatabasePool::SQLite(pool))
}
}
}
/// Detect database type from URL
pub fn detect_type(url: &str) -> Result<DatabaseType> {
if url.starts_with("postgres://") || url.starts_with("postgresql://") {
Ok(DatabaseType::PostgreSQL)
} else if url.starts_with("sqlite:") {
Ok(DatabaseType::SQLite)
} else {
Err(anyhow::anyhow!("Unsupported database URL: {}", url))
}
}
/// Get the database type
pub fn database_type(&self) -> DatabaseType {
match self {
DatabasePool::PostgreSQL(_) => DatabaseType::PostgreSQL,
DatabasePool::SQLite(_) => DatabaseType::SQLite,
}
}
/// Get PostgreSQL pool (if applicable)
pub fn as_postgres(&self) -> Option<&PgPool> {
match self {
DatabasePool::PostgreSQL(pool) => Some(pool),
_ => None,
}
}
/// Get SQLite pool (if applicable)
pub fn as_sqlite(&self) -> Option<&SqlitePool> {
match self {
DatabasePool::SQLite(pool) => Some(pool),
_ => None,
}
}
/// Close the database pool
pub async fn close(&self) {
match self {
DatabasePool::PostgreSQL(pool) => pool.close().await,
DatabasePool::SQLite(pool) => pool.close().await,
}
}
/// Check if the pool is closed
pub fn is_closed(&self) -> bool {
match self {
DatabasePool::PostgreSQL(pool) => pool.is_closed(),
DatabasePool::SQLite(pool) => pool.is_closed(),
}
}
/// Create a database connection from this pool
pub fn create_connection(&self) -> connection::DatabaseConnection {
connection::DatabaseConnection::from_pool(self)
}
}
/// Backend-independent accessors for reading typed values out of a fetched row.
///
/// Every accessor looks the column up by name. The implementations in this
/// module propagate lookup/decode failures as errors, and the `get_optional_*`
/// variants decode into `Option` so a SQL `NULL` becomes `None`.
pub trait DatabaseRow: Send + Sync {
    // Text
    fn get_string(&self, column: &str) -> Result<String>;
    fn get_optional_string(&self, column: &str) -> Result<Option<String>>;

    // Integers
    fn get_i32(&self, column: &str) -> Result<i32>;
    fn get_optional_i32(&self, column: &str) -> Result<Option<i32>>;
    fn get_i64(&self, column: &str) -> Result<i64>;
    fn get_optional_i64(&self, column: &str) -> Result<Option<i64>>;

    // Booleans
    fn get_bool(&self, column: &str) -> Result<bool>;
    fn get_optional_bool(&self, column: &str) -> Result<Option<bool>>;

    // Binary data
    fn get_bytes(&self, column: &str) -> Result<Vec<u8>>;
    fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>>;

    // UUIDs (only compiled with the `uuid` crate feature)
    #[cfg(feature = "uuid")]
    fn get_uuid(&self, column: &str) -> Result<Uuid>;
    #[cfg(feature = "uuid")]
    fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>>;

    // Timestamps, always returned in UTC
    fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>>;
    fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>>;
}
/// PostgreSQL row wrapper
#[derive(Debug)]
pub struct PostgresRow(pub sqlx::postgres::PgRow);
impl DatabaseRow for PostgresRow {
fn get_string(&self, column: &str) -> Result<String> {
Ok(self.0.try_get(column)?)
}
fn get_optional_string(&self, column: &str) -> Result<Option<String>> {
Ok(self.0.try_get(column)?)
}
fn get_i32(&self, column: &str) -> Result<i32> {
Ok(self.0.try_get(column)?)
}
fn get_optional_i32(&self, column: &str) -> Result<Option<i32>> {
Ok(self.0.try_get(column)?)
}
fn get_i64(&self, column: &str) -> Result<i64> {
Ok(self.0.try_get(column)?)
}
fn get_optional_i64(&self, column: &str) -> Result<Option<i64>> {
Ok(self.0.try_get(column)?)
}
fn get_bool(&self, column: &str) -> Result<bool> {
Ok(self.0.try_get(column)?)
}
fn get_optional_bool(&self, column: &str) -> Result<Option<bool>> {
Ok(self.0.try_get(column)?)
}
fn get_bytes(&self, column: &str) -> Result<Vec<u8>> {
Ok(self.0.try_get(column)?)
}
fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>> {
Ok(self.0.try_get(column)?)
}
#[cfg(feature = "uuid")]
fn get_uuid(&self, column: &str) -> Result<Uuid> {
Ok(self.0.try_get(column)?)
}
#[cfg(feature = "uuid")]
fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>> {
Ok(self.0.try_get(column)?)
}
fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>> {
Ok(self.0.try_get(column)?)
}
fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>> {
Ok(self.0.try_get(column)?)
}
}
/// SQLite row wrapper
pub struct SqliteRow(pub sqlx::sqlite::SqliteRow);
impl DatabaseRow for SqliteRow {
fn get_string(&self, column: &str) -> Result<String> {
Ok(self.0.try_get(column)?)
}
fn get_optional_string(&self, column: &str) -> Result<Option<String>> {
Ok(self.0.try_get(column)?)
}
fn get_i32(&self, column: &str) -> Result<i32> {
Ok(self.0.try_get(column)?)
}
fn get_optional_i32(&self, column: &str) -> Result<Option<i32>> {
Ok(self.0.try_get(column)?)
}
fn get_i64(&self, column: &str) -> Result<i64> {
Ok(self.0.try_get(column)?)
}
fn get_optional_i64(&self, column: &str) -> Result<Option<i64>> {
Ok(self.0.try_get(column)?)
}
fn get_bool(&self, column: &str) -> Result<bool> {
Ok(self.0.try_get(column)?)
}
fn get_optional_bool(&self, column: &str) -> Result<Option<bool>> {
Ok(self.0.try_get(column)?)
}
fn get_bytes(&self, column: &str) -> Result<Vec<u8>> {
Ok(self.0.try_get(column)?)
}
fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>> {
Ok(self.0.try_get(column)?)
}
#[cfg(feature = "uuid")]
fn get_uuid(&self, column: &str) -> Result<Uuid> {
// SQLite stores UUIDs as text
let uuid_str: String = self.0.try_get(column)?;
Ok(Uuid::parse_str(&uuid_str)?)
}
#[cfg(feature = "uuid")]
fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>> {
let uuid_str: Option<String> = self.0.try_get(column)?;
match uuid_str {
Some(s) => Ok(Some(Uuid::parse_str(&s)?)),
None => Ok(None),
}
}
fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>> {
// SQLite stores timestamps as text in ISO format
let timestamp_str: String = self.0.try_get(column)?;
Ok(DateTime::parse_from_rfc3339(&timestamp_str)?.with_timezone(&Utc))
}
fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>> {
let timestamp_str: Option<String> = self.0.try_get(column)?;
match timestamp_str {
Some(s) => Ok(Some(DateTime::parse_from_rfc3339(&s)?.with_timezone(&Utc))),
None => Ok(None),
}
}
}
/// Application-facing database handle.
///
/// A thin wrapper around a [`DatabasePool`]; cloning a `Database` clones the
/// wrapped pool.
#[derive(Debug, Clone)]
pub struct Database {
    /// The wrapped connection pool.
    pool: DatabasePool,
}
impl Database {
/// Create a new database instance
pub fn new(pool: DatabasePool) -> Self {
Self { pool }
}
/// Get the database pool
pub fn pool(&self) -> &DatabasePool {
&self.pool
}
/// Clone the database pool
#[allow(dead_code)]
pub fn pool_clone(&self) -> DatabasePool {
self.pool.clone()
}
/// Create a database connection from this database
#[allow(dead_code)]
pub fn create_connection(&self) -> connection::DatabaseConnection {
self.pool.create_connection()
}
}
// Convenience accessors that forward to the backend-specific pool getters.
impl Database {
    /// Borrow the PostgreSQL pool, or `None` when this database is backed by SQLite.
    pub fn as_pg_pool(&self) -> Option<&PgPool> {
        self.pool().as_postgres()
    }

    /// Borrow the SQLite pool, or `None` when this database is backed by PostgreSQL.
    pub fn as_sqlite_pool(&self) -> Option<&SqlitePool> {
        self.pool().as_sqlite()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Supported URL schemes map to their backend; anything else is rejected.
    #[test]
    fn test_database_type_detection() {
        let supported = [
            ("postgresql://user:pass@host/db", DatabaseType::PostgreSQL),
            ("postgres://user:pass@host/db", DatabaseType::PostgreSQL),
            ("sqlite:data/test.db", DatabaseType::SQLite),
        ];
        for (url, expected) in supported.iter() {
            let detected = DatabasePool::detect_type(url).unwrap();
            assert_eq!(&detected, expected, "url: {}", url);
        }

        let unsupported = DatabasePool::detect_type("mysql://user:pass@host/db");
        assert!(unsupported.is_err());
    }
}