//! Database abstraction layer supporting multiple database backends.
//!
//! This module provides a unified interface for database operations that works
//! with both SQLite and PostgreSQL, allowing the application to be database-agnostic.

use anyhow::Result;
|
|
use chrono::{DateTime, Utc};
|
|
use sqlx::{PgPool, Row, SqlitePool};
|
|
use std::time::Duration;
|
|
use uuid::Uuid;
|
|
|
|
pub mod auth;
|
|
pub mod connection;
|
|
pub mod migrations;
|
|
pub mod rbac;
|
|
|
|
/// Connection-pool settings consumed by `DatabasePool::new`.
///
/// The same values are applied regardless of which backend the URL selects.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
    /// Connection URL; its scheme selects the backend
    /// (`postgres://` / `postgresql://` or `sqlite:`).
    pub url: String,
    /// Maximum number of connections the pool may open (`max_connections`).
    pub max_connections: u32,
    /// Minimum number of connections the pool maintains (`min_connections`).
    pub min_connections: u32,
    /// Despite its name, this is passed to the pool's `acquire_timeout`,
    /// i.e. it bounds how long acquiring a pooled connection may take.
    pub connect_timeout: Duration,
    /// Passed to the pool's `idle_timeout`.
    pub idle_timeout: Duration,
    /// Passed to the pool's `max_lifetime`.
    pub max_lifetime: Duration,
}
|
|
|
|
/// The database backend a connection URL refers to.
///
/// The enum carries no data, so it is `Copy` (pass it by value) and `Hash`
/// (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// PostgreSQL, selected by `postgres://` or `postgresql://` URLs.
    PostgreSQL,
    /// SQLite, selected by `sqlite:` URLs.
    SQLite,
}
|
|
|
|
/// Database connection pool abstraction
|
|
#[derive(Debug, Clone)]
|
|
pub enum DatabasePool {
|
|
PostgreSQL(PgPool),
|
|
SQLite(SqlitePool),
|
|
}
|
|
|
|
impl DatabasePool {
|
|
/// Create a new database pool from configuration
|
|
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
|
|
let db_type = Self::detect_type(&config.url)?;
|
|
|
|
match db_type {
|
|
DatabaseType::PostgreSQL => {
|
|
let pool = sqlx::postgres::PgPoolOptions::new()
|
|
.max_connections(config.max_connections)
|
|
.min_connections(config.min_connections)
|
|
.acquire_timeout(config.connect_timeout)
|
|
.idle_timeout(config.idle_timeout)
|
|
.max_lifetime(config.max_lifetime)
|
|
.connect(&config.url)
|
|
.await?;
|
|
Ok(DatabasePool::PostgreSQL(pool))
|
|
}
|
|
DatabaseType::SQLite => {
|
|
// Ensure directory exists for SQLite
|
|
if let Some(path) = config.url.strip_prefix("sqlite:") {
|
|
if let Some(parent) = std::path::Path::new(path).parent() {
|
|
tokio::fs::create_dir_all(parent).await?;
|
|
}
|
|
}
|
|
|
|
let pool = sqlx::sqlite::SqlitePoolOptions::new()
|
|
.max_connections(config.max_connections)
|
|
.min_connections(config.min_connections)
|
|
.acquire_timeout(config.connect_timeout)
|
|
.idle_timeout(config.idle_timeout)
|
|
.max_lifetime(config.max_lifetime)
|
|
.connect(&config.url)
|
|
.await?;
|
|
Ok(DatabasePool::SQLite(pool))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Detect database type from URL
|
|
pub fn detect_type(url: &str) -> Result<DatabaseType> {
|
|
if url.starts_with("postgres://") || url.starts_with("postgresql://") {
|
|
Ok(DatabaseType::PostgreSQL)
|
|
} else if url.starts_with("sqlite:") {
|
|
Ok(DatabaseType::SQLite)
|
|
} else {
|
|
Err(anyhow::anyhow!("Unsupported database URL: {}", url))
|
|
}
|
|
}
|
|
|
|
/// Get the database type
|
|
pub fn database_type(&self) -> DatabaseType {
|
|
match self {
|
|
DatabasePool::PostgreSQL(_) => DatabaseType::PostgreSQL,
|
|
DatabasePool::SQLite(_) => DatabaseType::SQLite,
|
|
}
|
|
}
|
|
|
|
/// Get PostgreSQL pool (if applicable)
|
|
pub fn as_postgres(&self) -> Option<&PgPool> {
|
|
match self {
|
|
DatabasePool::PostgreSQL(pool) => Some(pool),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Get SQLite pool (if applicable)
|
|
pub fn as_sqlite(&self) -> Option<&SqlitePool> {
|
|
match self {
|
|
DatabasePool::SQLite(pool) => Some(pool),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Close the database pool
|
|
pub async fn close(&self) {
|
|
match self {
|
|
DatabasePool::PostgreSQL(pool) => pool.close().await,
|
|
DatabasePool::SQLite(pool) => pool.close().await,
|
|
}
|
|
}
|
|
|
|
/// Check if the pool is closed
|
|
pub fn is_closed(&self) -> bool {
|
|
match self {
|
|
DatabasePool::PostgreSQL(pool) => pool.is_closed(),
|
|
DatabasePool::SQLite(pool) => pool.is_closed(),
|
|
}
|
|
}
|
|
|
|
/// Create a database connection from this pool
|
|
pub fn create_connection(&self) -> connection::DatabaseConnection {
|
|
connection::DatabaseConnection::from_pool(self)
|
|
}
|
|
}
|
|
|
|
/// Database row trait for abstracting over different database row types
|
|
pub trait DatabaseRow: Send + Sync {
|
|
fn get_string(&self, column: &str) -> Result<String>;
|
|
fn get_optional_string(&self, column: &str) -> Result<Option<String>>;
|
|
fn get_i32(&self, column: &str) -> Result<i32>;
|
|
fn get_optional_i32(&self, column: &str) -> Result<Option<i32>>;
|
|
fn get_i64(&self, column: &str) -> Result<i64>;
|
|
fn get_optional_i64(&self, column: &str) -> Result<Option<i64>>;
|
|
fn get_bool(&self, column: &str) -> Result<bool>;
|
|
fn get_optional_bool(&self, column: &str) -> Result<Option<bool>>;
|
|
fn get_bytes(&self, column: &str) -> Result<Vec<u8>>;
|
|
fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>>;
|
|
|
|
#[cfg(feature = "uuid")]
|
|
fn get_uuid(&self, column: &str) -> Result<Uuid>;
|
|
#[cfg(feature = "uuid")]
|
|
fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>>;
|
|
|
|
fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>>;
|
|
fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>>;
|
|
}
|
|
|
|
/// PostgreSQL row wrapper
|
|
#[derive(Debug)]
|
|
pub struct PostgresRow(pub sqlx::postgres::PgRow);
|
|
|
|
impl DatabaseRow for PostgresRow {
|
|
fn get_string(&self, column: &str) -> Result<String> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_string(&self, column: &str) -> Result<Option<String>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_i32(&self, column: &str) -> Result<i32> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_i32(&self, column: &str) -> Result<Option<i32>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_i64(&self, column: &str) -> Result<i64> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_i64(&self, column: &str) -> Result<Option<i64>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_bool(&self, column: &str) -> Result<bool> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_bool(&self, column: &str) -> Result<Option<bool>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_bytes(&self, column: &str) -> Result<Vec<u8>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
#[cfg(feature = "uuid")]
|
|
fn get_uuid(&self, column: &str) -> Result<Uuid> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
#[cfg(feature = "uuid")]
|
|
fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
}
|
|
|
|
/// SQLite row wrapper
|
|
pub struct SqliteRow(pub sqlx::sqlite::SqliteRow);
|
|
|
|
impl DatabaseRow for SqliteRow {
|
|
fn get_string(&self, column: &str) -> Result<String> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_string(&self, column: &str) -> Result<Option<String>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_i32(&self, column: &str) -> Result<i32> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_i32(&self, column: &str) -> Result<Option<i32>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_i64(&self, column: &str) -> Result<i64> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_i64(&self, column: &str) -> Result<Option<i64>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_bool(&self, column: &str) -> Result<bool> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_bool(&self, column: &str) -> Result<Option<bool>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_bytes(&self, column: &str) -> Result<Vec<u8>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
fn get_optional_bytes(&self, column: &str) -> Result<Option<Vec<u8>>> {
|
|
Ok(self.0.try_get(column)?)
|
|
}
|
|
|
|
#[cfg(feature = "uuid")]
|
|
fn get_uuid(&self, column: &str) -> Result<Uuid> {
|
|
// SQLite stores UUIDs as text
|
|
let uuid_str: String = self.0.try_get(column)?;
|
|
Ok(Uuid::parse_str(&uuid_str)?)
|
|
}
|
|
|
|
#[cfg(feature = "uuid")]
|
|
fn get_optional_uuid(&self, column: &str) -> Result<Option<Uuid>> {
|
|
let uuid_str: Option<String> = self.0.try_get(column)?;
|
|
match uuid_str {
|
|
Some(s) => Ok(Some(Uuid::parse_str(&s)?)),
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
fn get_datetime(&self, column: &str) -> Result<DateTime<Utc>> {
|
|
// SQLite stores timestamps as text in ISO format
|
|
let timestamp_str: String = self.0.try_get(column)?;
|
|
Ok(DateTime::parse_from_rfc3339(×tamp_str)?.with_timezone(&Utc))
|
|
}
|
|
|
|
fn get_optional_datetime(&self, column: &str) -> Result<Option<DateTime<Utc>>> {
|
|
let timestamp_str: Option<String> = self.0.try_get(column)?;
|
|
match timestamp_str {
|
|
Some(s) => Ok(Some(DateTime::parse_from_rfc3339(&s)?.with_timezone(&Utc))),
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Database wrapper struct
|
|
#[derive(Debug, Clone)]
|
|
pub struct Database {
|
|
pool: DatabasePool,
|
|
}
|
|
|
|
impl Database {
|
|
/// Create a new database instance
|
|
pub fn new(pool: DatabasePool) -> Self {
|
|
Self { pool }
|
|
}
|
|
|
|
/// Get the database pool
|
|
pub fn pool(&self) -> &DatabasePool {
|
|
&self.pool
|
|
}
|
|
|
|
/// Clone the database pool
|
|
#[allow(dead_code)]
|
|
pub fn pool_clone(&self) -> DatabasePool {
|
|
self.pool.clone()
|
|
}
|
|
|
|
/// Create a database connection from this database
|
|
#[allow(dead_code)]
|
|
pub fn create_connection(&self) -> connection::DatabaseConnection {
|
|
self.pool.create_connection()
|
|
}
|
|
}
|
|
|
|
// Convenience methods for accessing underlying pools
|
|
impl Database {
|
|
/// Get PostgreSQL pool if available
|
|
pub fn as_pg_pool(&self) -> Option<&PgPool> {
|
|
self.pool.as_postgres()
|
|
}
|
|
|
|
/// Get SQLite pool if available
|
|
pub fn as_sqlite_pool(&self) -> Option<&SqlitePool> {
|
|
self.pool.as_sqlite()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_type_detection() {
        // Each supported URL form and the backend it must map to.
        let supported = [
            ("postgresql://user:pass@host/db", DatabaseType::PostgreSQL),
            ("postgres://user:pass@host/db", DatabaseType::PostgreSQL),
            ("sqlite:data/test.db", DatabaseType::SQLite),
        ];
        for (url, expected) in supported.iter() {
            assert_eq!(
                &DatabasePool::detect_type(url).unwrap(),
                expected,
                "url: {}",
                url
            );
        }

        // Schemes other than PostgreSQL/SQLite are rejected.
        assert!(DatabasePool::detect_type("mysql://user:pass@host/db").is_err());
    }
}
|