use anyhow::Result; use axum::{ Router, extract::State, http::{HeaderMap, StatusCode}, middleware, response::Json, routing::{get, post}, }; use serde_json::json; use std::sync::Arc; use tokio::time::{Duration as TokioDuration, interval}; use tower::ServiceBuilder; use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use crate::auth::{JwtService, RBACConfigLoader, RBACRepository, RBACService, auth_middleware}; use crate::database::{Database, DatabaseConfig, DatabasePool}; use crate::examples::rbac_integration::{ AppState, create_rbac_routes, initialize_rbac_system, setup_rbac_middleware, }; use std::time::Duration as StdDuration; /// Main server configuration with RBAC pub struct RBACServer { pub app_state: AppState, pub host: String, pub port: u16, } impl RBACServer { /// Create a new RBAC-enabled server pub async fn new( database_url: &str, rbac_config_path: &str, jwt_secret: &str, host: String, port: u16, ) -> Result { // Initialize database connection using new abstraction let database_config = DatabaseConfig { url: database_url.to_string(), max_connections: 20, min_connections: 1, connect_timeout: StdDuration::from_secs(30), idle_timeout: StdDuration::from_secs(600), max_lifetime: StdDuration::from_secs(3600), }; let database_pool = DatabasePool::new(&database_config).await?; let database = Database::new(database_pool.clone()); // Initialize repositories using new database abstraction let auth_repository = Arc::new(crate::database::auth::AuthRepository::new( database.create_connection(), )); let rbac_repository = Arc::new(RBACRepository::from_database_pool(&database_pool)); // Initialize JWT service let jwt_service = Arc::new( JwtService::new() .map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?, ); // Initialize RBAC service let rbac_service = Arc::new(RBACService::new(rbac_repository.clone())); // Load RBAC configuration let config_loader = RBACConfigLoader::new(rbac_config_path); if !config_loader.config_exists() { 
println!("Creating default RBAC configuration..."); config_loader.create_default_config().await?; } // Load and save config to database let rbac_config = config_loader.load_from_file().await?; rbac_service .save_rbac_config("default", &rbac_config, Some("Server initialization")) .await?; println!( "RBAC system initialized with {} rules", rbac_config.rules.len() ); let app_state = AppState { rbac_service, rbac_repository, auth_repository, jwt_service, }; Ok(Self { app_state, host, port, }) } /// Build the application router with RBAC middleware pub fn build_router(&self) -> Router { let app = Router::new() // Health check endpoint .route("/health", get(health_check)) .route("/api/health", get(api_health_check)) // Authentication routes (no RBAC required) .route("/api/auth/login", post(auth_login)) .route("/api/auth/register", post(auth_register)) .route("/api/auth/refresh", post(auth_refresh)) .route("/api/auth/logout", post(auth_logout)) // User profile routes (basic auth required) .route("/api/user/profile", get(get_user_profile)) .route("/api/user/profile", post(update_user_profile)) // RBAC management routes (admin only) .route("/api/rbac/config", get(get_rbac_config)) .route("/api/rbac/config", post(update_rbac_config)) .route("/api/rbac/categories", get(list_categories)) .route("/api/rbac/categories", post(create_category)) .route("/api/rbac/tags", get(list_tags)) .route("/api/rbac/tags", post(create_tag)) .route( "/api/rbac/users/:user_id/categories", get(get_user_categories), ) .route( "/api/rbac/users/:user_id/categories", post(assign_user_category), ) .route("/api/rbac/users/:user_id/tags", get(get_user_tags)) .route("/api/rbac/users/:user_id/tags", post(assign_user_tag)) .route("/api/rbac/audit/:user_id", get(get_access_audit)) // Merge with RBAC-protected routes .merge(create_rbac_routes(self.app_state.clone())) // Apply global middleware .layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) .layer(CorsLayer::permissive()) 
.layer(middleware::from_fn_with_state( ( self.app_state.jwt_service.clone(), self.app_state.auth_repository.clone(), ), auth_middleware, )) .layer(middleware::from_fn_with_state( self.app_state.rbac_service.clone(), crate::auth::rbac_middleware::rbac_middleware, )), ) .with_state(self.app_state.clone()); // Apply RBAC middleware to specific routes setup_rbac_middleware(app) } /// Start the server pub async fn start(&self) -> Result<()> { let app = self.build_router(); let addr = format!("{}:{}", self.host, self.port); println!("Starting RBAC-enabled server on {}", addr); // Start background tasks let cleanup_state = self.app_state.clone(); tokio::spawn(async move { let mut interval = interval(Duration::from_secs(300)); // 5 minutes loop { interval.tick().await; if let Err(e) = cleanup_state.rbac_service.cleanup_expired_cache().await { eprintln!("Error cleaning up expired cache: {}", e); } } }); // Start the server let listener = tokio::net::TcpListener::bind(&addr).await?; axum::serve(listener, app).await?; Ok(()) } } /// Health check endpoint async fn health_check() -> Result, StatusCode> { Ok(Json(json!({ "status": "ok", "timestamp": chrono::Utc::now(), "service": "rustelo-rbac" }))) } /// API health check with more details async fn api_health_check( State(state): State, ) -> Result, StatusCode> { // Check database connectivity let db_status = match state.rbac_repository.get_rbac_config("default").await { Ok(_) => "connected", Err(_) => "disconnected", }; Ok(Json(json!({ "status": "ok", "timestamp": chrono::Utc::now(), "service": "rustelo-rbac", "database": db_status, "rbac_enabled": true }))) } /// Authentication login endpoint async fn auth_login( State(state): State, Json(credentials): Json, ) -> Result, StatusCode> { // This is a simplified example - in a real implementation, // you'd use the full AuthService Ok(Json(json!({ "success": true, "message": "Login endpoint - implement with AuthService", "email": credentials.email }))) } /// Authentication register 
/// Placeholder: always reports success and echoes `user_data.email`.
/// It creates no account; the message says to implement it with `AuthService`.
async fn auth_register(
    State(state): State,
    Json(user_data): Json,
) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Register endpoint - implement with AuthService",
        "email": user_data.email
    })))
}

/// Token refresh endpoint (placeholder).
///
/// Ignores both the state and the request body and always reports success.
async fn auth_refresh(
    State(state): State,
    Json(request): Json,
) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Refresh endpoint - implement with AuthService"
    })))
}

/// Logout endpoint.
///
/// Only returns a success message. The body calls no service, so no token or
/// session is invalidated here.
async fn auth_logout(State(state): State) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Logged out successfully"
    })))
}

/// Get user profile (placeholder): returns a fixed message and no profile data.
async fn get_user_profile(
    State(state): State,
) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "User profile endpoint - implement with AuthService"
    })))
}

/// Update user profile (placeholder): the submitted `profile` is ignored.
async fn update_user_profile(
    State(state): State,
    Json(profile): Json,
) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Profile update endpoint - implement with AuthService"
    })))
}

/// Get RBAC configuration.
///
/// Loads the config named `"default"` via the RBAC service. Any error becomes
/// a bare 500; the error detail is discarded.
async fn get_rbac_config(
    State(state): State,
) -> Result, StatusCode> {
    match state.rbac_service.get_rbac_config("default").await {
        Ok(config) => Ok(Json(json!({
            "success": true,
            "config": config
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Update RBAC configuration.
///
/// Saves the request body as the `"default"` config with the note
/// `"Updated via API"`. Any error becomes a bare 500.
async fn update_rbac_config(
    State(state): State,
    Json(config): Json,
) -> Result, StatusCode> {
    match state
        .rbac_service
        .save_rbac_config("default", &config, Some("Updated via API"))
        .await
    {
        Ok(_) => Ok(Json(json!({
            "success": true,
            "message": "RBAC configuration updated"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// List all categories.
///
/// NOTE(review): returns a hard-coded list and does not query the repository,
/// so categories added via `create_category` will not appear here.
async fn list_categories(
    State(state): State,
) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "categories": ["admin", "editor", "viewer", "finance", "hr", "it"],
        "message": "Categories retrieved successfully"
    })))
}

/// Create a new category from the request's JSON `"name"` and optional `"description"`.
///
/// Responds with the created row's `id`, `name` and `description`. Any
/// repository error becomes a bare 500.
async fn create_category(
    State(state): State,
    Json(category): Json,
) -> Result, StatusCode> {
    // NOTE(review): a missing or non-string "name" silently creates a category
    // called "unnamed"; rejecting with 400 Bad Request may be more appropriate.
    let name = category
        .get("name")
        .and_then(|n| n.as_str())
        .unwrap_or("unnamed");
    let description = category.get("description").and_then(|d| d.as_str());

    match state
        .rbac_repository
        .create_category(name, description, None)
        .await
    {
        Ok(created_category) => Ok(Json(json!({
            "success": true,
            "category": {
                "id": created_category.id,
                "name": created_category.name,
                "description": created_category.description
            },
            "message": "Category created successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// List all tags.
///
/// NOTE(review): returns a hard-coded list and does not query the repository,
/// so tags added via `create_tag` will not appear here.
async fn list_tags(State(state): State) -> Result, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "tags": ["sensitive", "public", "internal", "confidential", "restricted", "temporary"],
        "message": "Tags retrieved successfully"
    })))
}

/// Create a new tag from the request's JSON `"name"`, plus optional `"description"` and `"color"`.
///
/// Any repository error becomes a bare 500.
async fn create_tag(
    State(state): State,
    Json(tag): Json,
) -> Result, StatusCode> {
    // NOTE(review): a missing or non-string "name" silently becomes "unnamed".
    let name = tag
        .get("name")
        .and_then(|n| n.as_str())
        .unwrap_or("unnamed");
    let description = tag.get("description").and_then(|d| d.as_str());
    let color = tag.get("color").and_then(|c| c.as_str());

    match state
        .rbac_repository
        .create_tag(name, description, color)
        .await
    {
        Ok(created_tag) => Ok(Json(json!({
            "success": true,
            "tag": {
                "id": created_tag.id,
                "name": created_tag.name,
                "description": created_tag.description,
                "color": created_tag.color
            },
            "message": "Tag created successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Get the categories assigned to the user identified by the `user_id` path segment.
async fn get_user_categories(
    axum::extract::Path(user_id): axum::extract::Path,
    State(state): State,
) -> Result, StatusCode> {
    match state.rbac_repository.get_user_categories(user_id).await {
        Ok(categories) => Ok(Json(json!({
            "success": true,
            "user_id": user_id,
            "categories": categories
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Assign the category named in the request's JSON `"category"` field to the user.
async fn assign_user_category(
    axum::extract::Path(user_id): axum::extract::Path,
    State(state): State,
    Json(request): Json,
) -> Result, StatusCode> {
    // NOTE(review): a missing field passes "" to the service, which then
    // surfaces as a 500 rather than a 400.
    let category_name = request
        .get("category")
        .and_then(|c| c.as_str())
        .unwrap_or("");

    match state
        .rbac_service
        .assign_category_to_user(user_id, category_name, None, None)
        .await
    {
        Ok(_) => Ok(Json(json!({
            "success": true,
            "message": "Category assigned successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Get the tags assigned to the user identified by the `user_id` path segment.
async fn get_user_tags(
    axum::extract::Path(user_id): axum::extract::Path,
    State(state): State,
) -> Result, StatusCode> {
    match state.rbac_repository.get_user_tags(user_id).await {
        Ok(tags) => Ok(Json(json!({
            "success": true,
            "user_id": user_id,
            "tags": tags
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Assign the tag named in the request's JSON `"tag"` field to the user.
async fn assign_user_tag(
    axum::extract::Path(user_id): axum::extract::Path,
    State(state): State,
    Json(request): Json,
) -> Result, StatusCode> {
    // NOTE(review): a missing field passes "" to the service (see assign_user_category).
    let tag_name = request.get("tag").and_then(|t| t.as_str()).unwrap_or("");

    match state
        .rbac_service
        .assign_tag_to_user(user_id, tag_name, None, None)
        .await
    {
        Ok(_) => Ok(Json(json!({
            "success": true,
            "message": "Tag assigned successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Get the access audit history for the user.
async fn get_access_audit(
    axum::extract::Path(user_id): axum::extract::Path,
    State(state): State,
) -> Result, StatusCode> {
    // Presumably `100` caps the number of history entries returned — confirm
    // against `get_user_access_history`.
    match state
        .rbac_service
        .get_user_access_history(user_id, 100)
        .await
    {
        Ok(history) => Ok(Json(json!({
            "success": true,
            "user_id": user_id,
            "audit_log": history
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// CLI entry point for the RBAC server.
///
/// Reads `DATABASE_URL`, `RBAC_CONFIG_PATH`, `JWT_SECRET`, `SERVER_HOST` and
/// `SERVER_PORT` from the environment, each with a development default. It then
/// builds the server and serves until it stops or fails.
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing
    tracing_subscriber::fmt::init();

    // Load configuration from environment
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://dev:dev@localhost:5432/rustelo_dev".to_string());
    let rbac_config_path =
        std::env::var("RBAC_CONFIG_PATH").unwrap_or_else(|_| "config/rbac.toml".to_string());
    // NOTE(review): falls back to a well-known placeholder secret when JWT_SECRET
    // is unset; consider failing fast outside development.
    let jwt_secret =
        std::env::var("JWT_SECRET")
        .unwrap_or_else(|_| "your-super-secret-jwt-key-change-this-in-production".to_string());
    let host = std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
    // A missing or unparsable SERVER_PORT silently falls back to 3030.
    let port = std::env::var("SERVER_PORT")
        .unwrap_or_else(|_| "3030".to_string())
        .parse::()
        .unwrap_or(3030);

    println!("Initializing RBAC server...");
    // NOTE(review): this prints the full connection string, which can contain a
    // password (the default does: `dev:dev@`); consider redacting before logging.
    println!("Database URL: {}", database_url);
    println!("RBAC Config: {}", rbac_config_path);
    println!("Server: {}:{}", host, port);

    // Create and start server
    let server = RBACServer::new(&database_url, &rbac_config_path, &jwt_secret, host, port).await?;
    server.start().await?;

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `Body`, `Method`, `Request` and `ServiceExt` are unused by the
    // tests below (unused-import warnings); presumably meant for router-level tests.
    use axum::body::Body;
    use axum::http::{Method, Request, StatusCode};
    use tower::ServiceExt;

    /// The liveness endpoint reports `"status": "ok"`.
    #[tokio::test]
    async fn test_health_check() {
        let response = health_check().await.unwrap();
        assert!(response.0.get("status").unwrap().as_str().unwrap() == "ok");
    }

    #[tokio::test]
    async fn test_server_creation() {
        // This would require a test database setup
        // For now, just test that the structure compiles
        // NOTE(review): `assert!(true)` checks nothing at runtime (clippy: assertions_on_constants).
        assert!(true);
    }
}