use anyhow::Result; use axum::{ Router, extract::State, http::{HeaderMap, StatusCode}, middleware, response::Json, routing::{get, post}, }; use serde::{Deserialize, Serialize}; use serde_json::json; use std::sync::Arc; use tokio::time::{Duration as TokioDuration, interval}; use tower::ServiceBuilder; use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use crate::auth::{ConditionalRBACService, JwtService, auth_middleware}; use crate::config::{Config, features::FeatureConfig}; use crate::database::{Database, DatabaseConfig, DatabasePool}; use std::time::Duration as StdDuration; /// Main application state with optional RBAC support #[derive(Clone)] pub struct AppState { pub auth_repository: Arc, pub jwt_service: Arc, pub rbac_service: Arc, pub config: Arc, pub feature_config: Arc, } /// RBAC-compatible server that can run with or without RBAC features pub struct RBACCompatibleServer { pub app_state: AppState, pub config: Arc, } impl RBACCompatibleServer { /// Create a new server with optional RBAC features pub async fn new(config: Config) -> Result { let config = Arc::new(config); // Initialize database connection using new abstraction let database_config = DatabaseConfig { url: config.database.url.clone(), max_connections: config.database.max_connections, min_connections: config.database.min_connections, connect_timeout: StdDuration::from_secs(config.database.connect_timeout), idle_timeout: StdDuration::from_secs(config.database.idle_timeout), max_lifetime: StdDuration::from_secs(config.database.max_lifetime), }; let database_pool = DatabasePool::new(&database_config).await?; let database = Database::new(database_pool.clone()); // Initialize feature configuration let mut feature_config = FeatureConfig::from_env(); // Override with config file settings if config.features.rbac { feature_config.enable_rbac(); } if config.features.rbac_database_access { feature_config.rbac.database_access = true; } if config.features.rbac_file_access { feature_config.rbac.file_access = true; } if config.features.rbac_content_access { feature_config.rbac.content_access = true; } if config.features.rbac_categories { feature_config.rbac.categories = true; } if config.features.rbac_tags { feature_config.rbac.tags = true; } if config.features.rbac_caching { feature_config.rbac.caching = true; } if config.features.rbac_audit_logging { feature_config.rbac.audit_logging = true; } let feature_config = Arc::new(feature_config); // Initialize core authentication services let auth_repository = Arc::new(crate::database::auth::AuthRepository::new( database.create_connection(), )); let jwt_service = Arc::new( JwtService::new() .map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?, ); // Initialize conditional RBAC service let rbac_config_path = if feature_config.is_rbac_feature_enabled("toml_config") { Some("config/rbac.toml") } else { None }; let rbac_service = Arc::new( ConditionalRBACService::new(&database_pool, feature_config.clone(), rbac_config_path) .await?, ); let app_state = AppState { auth_repository, jwt_service, rbac_service, config: config.clone(), feature_config, }; Ok(Self { app_state, config }) } /// Build the application router with conditional RBAC middleware pub fn build_router(&self) -> Router { let mut app = Router::new() // Health check endpoints .route("/health", get(health_check)) .route("/api/health", get(api_health_check)) .route("/api/features", get(feature_status)) // Authentication routes (always available) .route("/api/auth/login", post(auth_login)) .route("/api/auth/register", post(auth_register)) .route("/api/auth/refresh", post(auth_refresh)) .route("/api/auth/logout", post(auth_logout)) // User profile routes .route("/api/user/profile", get(get_user_profile)) .route("/api/user/profile", post(update_user_profile)) // Content routes (with conditional RBAC protection) .route("/api/content/:content_id", get(get_content)) .route("/api/content/:content_id", post(update_content)) // Database access routes (with conditional RBAC protection) .route("/api/database/:db_name", get(get_database_info)) .route("/api/database/:db_name/query", post(execute_database_query)) // File access routes (with conditional RBAC protection) .route("/api/files/*path", get(read_file)) .route("/api/files/*path", post(write_file)) // Admin routes (always require admin role, optionally enhanced with RBAC) .route("/api/admin/users", get(list_users)) .route("/api/admin/users/:user_id", get(get_user)) .route("/api/admin/users/:user_id", post(update_user)) .route( "/api/admin/users/:user_id", axum::routing::delete(delete_user), ); // Add RBAC-specific routes if enabled if let Some(rbac_routes) = self.app_state.rbac_service.create_rbac_routes() { app = app.merge(rbac_routes); } // Apply middleware layers app = app.layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(CorsLayer::permissive()) // Always apply authentication middleware .layer(middleware::from_fn_with_state( ( self.app_state.jwt_service.clone(), self.app_state.auth_repository.clone(), ), auth_middleware, )), ); // Apply RBAC middleware conditionally app = app.apply_rbac_if_enabled(&self.app_state.rbac_service); // Apply specific RBAC middleware to protected routes app = self.apply_route_specific_rbac_middleware(app); app.with_state(self.app_state.clone()) } /// Apply route-specific RBAC middleware conditionally fn apply_route_specific_rbac_middleware( &self, mut router: Router, ) -> Router { // Database access protection if self .app_state .rbac_service .is_feature_enabled("database_access") { router = router .route( "/api/database/:db_name", get(get_database_info).layer(middleware::from_fn( self.app_state .rbac_service .database_access_middleware("*".to_string(), "read".to_string()) .unwrap(), )), ) .route( "/api/database/:db_name/query", post(execute_database_query).layer(middleware::from_fn( self.app_state .rbac_service .database_access_middleware("*".to_string(), "write".to_string()) .unwrap(), )), ); } // File access protection if self .app_state .rbac_service .is_feature_enabled("file_access") { router = router .route( "/api/files/*path", get(read_file).layer(middleware::from_fn( self.app_state .rbac_service .file_access_middleware("*".to_string(), "read".to_string()) .unwrap(), )), ) .route( "/api/files/*path", post(write_file).layer(middleware::from_fn( self.app_state .rbac_service .file_access_middleware("*".to_string(), "write".to_string()) .unwrap(), )), ); } // Content access protection if self .app_state .rbac_service .is_feature_enabled("content_access") { router = router .route( "/api/content/:content_id", get(get_content).layer(middleware::from_fn( self.app_state .rbac_service .content_access_middleware("*".to_string(), "read".to_string()) .unwrap(), )), ) .route( "/api/content/:content_id", post(update_content).layer(middleware::from_fn( self.app_state .rbac_service .content_access_middleware("*".to_string(), "write".to_string()) .unwrap(), )), ); } // Admin routes with category protection if self.app_state.rbac_service.is_feature_enabled("categories") { router = router .route( "/api/admin/users", get(list_users).layer(middleware::from_fn( self.app_state .rbac_service .category_access_middleware(vec!["admin".to_string()]) .unwrap(), )), ) .route( "/api/admin/users/:user_id", get(get_user).layer(middleware::from_fn( self.app_state .rbac_service .category_access_middleware(vec!["admin".to_string()]) .unwrap(), )), ); } router } /// Start the server pub async fn start(&self) -> Result<()> { let app = self.build_router(); let addr = self.config.server_address(); // Print startup information println!("๐Ÿš€ Starting Rustelo server..."); println!("๐Ÿ“ Address: {}", addr); println!("๐ŸŒ Environment: {:?}", self.config.server.environment); if self.app_state.rbac_service.is_enabled() { println!("๐Ÿ” RBAC System: Enabled"); let status = self.app_state.rbac_service.get_feature_status(); println!( " โ””โ”€ Features: {}", serde_json::to_string_pretty(&status["features"])? ); } else { println!("๐Ÿ”’ RBAC System: Disabled (using basic role-based auth)"); } // Start background tasks self.start_background_tasks().await; // Start the server let listener = tokio::net::TcpListener::bind(&addr).await?; println!("โœ… Server running on {}", addr); axum::serve(listener, app).await?; Ok(()) } /// Start background tasks conditionally async fn start_background_tasks(&self) { // Start RBAC background tasks if enabled self.app_state.rbac_service.start_background_tasks().await; // Start general maintenance tasks tokio::spawn(async { let mut interval = interval(Duration::from_secs(3600)); // 1 hour loop { interval.tick().await; println!("๐Ÿงน Running periodic maintenance tasks..."); // Add general cleanup tasks here } }); println!("๐Ÿš€ Background tasks started"); } } // ============================================================================= // Route Handlers // ============================================================================= /// Health check endpoint async fn health_check() -> Result, StatusCode> { Ok(Json(json!({ "status": "ok", "timestamp": chrono::Utc::now(), "service": "rustelo" }))) } /// API health check with database connectivity async fn api_health_check( State(state): State, ) -> Result, StatusCode> { // Check database connectivity let db_status = match sqlx::query("SELECT 1") .fetch_one(&state.auth_repository.pool) .await { Ok(_) => "connected", Err(_) => "disconnected", }; Ok(Json(json!({ "status": "ok", "timestamp": chrono::Utc::now(), "service": "rustelo", "database": db_status, "rbac_enabled": state.rbac_service.is_enabled() }))) } /// Feature status endpoint async fn feature_status( State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "auth": state.feature_config.auth.enabled, "rbac": state.rbac_service.get_feature_status(), "content": state.feature_config.content.enabled, "security": { "csrf": state.feature_config.security.csrf, "rate_limiting": state.feature_config.security.rate_limiting }, "performance": { "caching": state.feature_config.performance.response_caching, "compression": state.feature_config.performance.compression } }))) } /// Authentication endpoints (simplified - integrate with full auth service) async fn auth_login( State(state): State, Json(credentials): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "message": "Login successful", "rbac_enabled": state.rbac_service.is_enabled() }))) } async fn auth_register( State(state): State, Json(user_data): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "message": "Registration successful" }))) } async fn auth_refresh( State(state): State, Json(request): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "message": "Token refreshed" }))) } async fn auth_logout(State(state): State) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "message": "Logged out successfully" }))) } /// User profile endpoints async fn get_user_profile( State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "profile": { "username": "example_user", "email": "user@example.com", "roles": ["user"] } }))) } async fn update_user_profile( State(state): State, Json(profile): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "message": "Profile updated successfully" }))) } /// Content endpoints (protected by conditional RBAC) async fn get_content( axum::extract::Path(content_id): axum::extract::Path, State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "content_id": content_id, "title": "Example Content", "body": "This is example content...", "protection": if state.rbac_service.is_feature_enabled("content_access") { "RBAC Protected" } else { "Basic Role Protected" } }))) } async fn update_content( axum::extract::Path(content_id): axum::extract::Path, State(state): State, Json(content): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "content_id": content_id, "message": "Content updated successfully" }))) } /// Database endpoints (protected by conditional RBAC) async fn get_database_info( axum::extract::Path(db_name): axum::extract::Path, State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "database": db_name, "status": "accessible", "protection": if state.rbac_service.is_feature_enabled("database_access") { "RBAC Protected" } else { "Basic Role Protected" } }))) } async fn execute_database_query( axum::extract::Path(db_name): axum::extract::Path, State(state): State, Json(query): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "database": db_name, "message": "Query executed successfully" }))) } /// File endpoints (protected by conditional RBAC) async fn read_file( axum::extract::Path(file_path): axum::extract::Path, State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "file_path": file_path, "content": "File content here...", "protection": if state.rbac_service.is_feature_enabled("file_access") { "RBAC Protected" } else { "Basic Role Protected" } }))) } async fn write_file( axum::extract::Path(file_path): axum::extract::Path, State(state): State, Json(request): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "file_path": file_path, "message": "File written successfully" }))) } /// Admin endpoints (protected by conditional RBAC categories) async fn list_users(State(state): State) -> Result, StatusCode> { Ok(Json(json!({ "users": [], "protection": if state.rbac_service.is_feature_enabled("categories") { "RBAC Category Protected (admin)" } else { "Basic Admin Role Protected" } }))) } async fn get_user( axum::extract::Path(user_id): axum::extract::Path, State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "user_id": user_id, "user": {}, "message": "User retrieved successfully" }))) } async fn update_user( axum::extract::Path(user_id): axum::extract::Path, State(state): State, Json(user_data): Json, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "user_id": user_id, "message": "User updated successfully" }))) } async fn delete_user( axum::extract::Path(user_id): axum::extract::Path, State(state): State, ) -> Result, StatusCode> { Ok(Json(json!({ "success": true, "user_id": user_id, "message": "User deleted successfully" }))) } /// CLI entry point for the RBAC-compatible server #[tokio::main] async fn main() -> Result<()> { // Initialize tracing tracing_subscriber::fmt::init(); // Load configuration let config = Config::load().await?; // Validate configuration config.validate()?; // Create and start server let server = RBACCompatibleServer::new(config).await?; server.start().await?; Ok(()) } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_server_creation() { let config = Config::default(); // This would require a test database setup // For now, just test that the structure compiles assert!(true); } #[test] fn test_conditional_middleware() { let feature_config = Arc::new(FeatureConfig::default()); let rbac_service = ConditionalRBACService { rbac_service: None, rbac_repository: None, feature_config, }; // Test that middleware functions return None when RBAC is disabled assert!( rbac_service .database_access_middleware("test".to_string(), "read".to_string()) .is_none() ); assert!( rbac_service .file_access_middleware("test".to_string(), "read".to_string()) .is_none() ); assert!( rbac_service .category_access_middleware(vec!["admin".to_string()]) .is_none() ); } }