// Rustelo/server/src/rbac_server.rs
// 2025-07-07 23:05:19 +01:00
//
// 652 lines
// 21 KiB
// Rust

use anyhow::Result;
use axum::{
Router,
extract::State,
http::{HeaderMap, StatusCode},
middleware,
response::Json,
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tokio::time::{Duration as TokioDuration, interval};
use tower::ServiceBuilder;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use crate::auth::{ConditionalRBACService, JwtService, auth_middleware};
use crate::config::{Config, features::FeatureConfig};
use crate::database::{Database, DatabaseConfig, DatabasePool};
use std::time::Duration as StdDuration;
/// Main application state with optional RBAC support.
///
/// Every member is wrapped in an `Arc`, so cloning the state (as the axum
/// `State` extractor does) only bumps reference counts.
#[derive(Clone)]
pub struct AppState {
    /// Authentication repository; its `pool` is also used by the API health check.
    pub auth_repository: Arc<crate::database::auth::AuthRepository>,
    /// JWT service, handed to the authentication middleware.
    pub jwt_service: Arc<JwtService>,
    /// RBAC service that is active or inert depending on the feature flags.
    pub rbac_service: Arc<ConditionalRBACService>,
    /// Loaded application configuration.
    pub config: Arc<Config>,
    /// Resolved feature flags (environment plus config-file overrides).
    pub feature_config: Arc<FeatureConfig>,
}
/// RBAC-compatible server that can run with or without RBAC features.
///
/// Holds the shared handler state plus the configuration used for startup
/// (bind address, environment banner).
pub struct RBACCompatibleServer {
    /// State shared with every route handler.
    pub app_state: AppState,
    /// Same configuration instance as `app_state.config`.
    pub config: Arc<Config>,
}
impl RBACCompatibleServer {
/// Create a new server with optional RBAC features
pub async fn new(config: Config) -> Result<Self> {
let config = Arc::new(config);
// Initialize database connection using new abstraction
let database_config = DatabaseConfig {
url: config.database.url.clone(),
max_connections: config.database.max_connections,
min_connections: config.database.min_connections,
connect_timeout: StdDuration::from_secs(config.database.connect_timeout),
idle_timeout: StdDuration::from_secs(config.database.idle_timeout),
max_lifetime: StdDuration::from_secs(config.database.max_lifetime),
};
let database_pool = DatabasePool::new(&database_config).await?;
let database = Database::new(database_pool.clone());
// Initialize feature configuration
let mut feature_config = FeatureConfig::from_env();
// Override with config file settings
if config.features.rbac {
feature_config.enable_rbac();
}
if config.features.rbac_database_access {
feature_config.rbac.database_access = true;
}
if config.features.rbac_file_access {
feature_config.rbac.file_access = true;
}
if config.features.rbac_content_access {
feature_config.rbac.content_access = true;
}
if config.features.rbac_categories {
feature_config.rbac.categories = true;
}
if config.features.rbac_tags {
feature_config.rbac.tags = true;
}
if config.features.rbac_caching {
feature_config.rbac.caching = true;
}
if config.features.rbac_audit_logging {
feature_config.rbac.audit_logging = true;
}
let feature_config = Arc::new(feature_config);
// Initialize core authentication services
let auth_repository = Arc::new(crate::database::auth::AuthRepository::new(
database.create_connection(),
));
let jwt_service = Arc::new(
JwtService::new()
.map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?,
);
// Initialize conditional RBAC service
let rbac_config_path = if feature_config.is_rbac_feature_enabled("toml_config") {
Some("config/rbac.toml")
} else {
None
};
let rbac_service = Arc::new(
ConditionalRBACService::new(&database_pool, feature_config.clone(), rbac_config_path)
.await?,
);
let app_state = AppState {
auth_repository,
jwt_service,
rbac_service,
config: config.clone(),
feature_config,
};
Ok(Self { app_state, config })
}
/// Build the application router with conditional RBAC middleware
pub fn build_router(&self) -> Router {
let mut app = Router::new()
// Health check endpoints
.route("/health", get(health_check))
.route("/api/health", get(api_health_check))
.route("/api/features", get(feature_status))
// Authentication routes (always available)
.route("/api/auth/login", post(auth_login))
.route("/api/auth/register", post(auth_register))
.route("/api/auth/refresh", post(auth_refresh))
.route("/api/auth/logout", post(auth_logout))
// User profile routes
.route("/api/user/profile", get(get_user_profile))
.route("/api/user/profile", post(update_user_profile))
// Content routes (with conditional RBAC protection)
.route("/api/content/:content_id", get(get_content))
.route("/api/content/:content_id", post(update_content))
// Database access routes (with conditional RBAC protection)
.route("/api/database/:db_name", get(get_database_info))
.route("/api/database/:db_name/query", post(execute_database_query))
// File access routes (with conditional RBAC protection)
.route("/api/files/*path", get(read_file))
.route("/api/files/*path", post(write_file))
// Admin routes (always require admin role, optionally enhanced with RBAC)
.route("/api/admin/users", get(list_users))
.route("/api/admin/users/:user_id", get(get_user))
.route("/api/admin/users/:user_id", post(update_user))
.route(
"/api/admin/users/:user_id",
axum::routing::delete(delete_user),
);
// Add RBAC-specific routes if enabled
if let Some(rbac_routes) = self.app_state.rbac_service.create_rbac_routes() {
app = app.merge(rbac_routes);
}
// Apply middleware layers
app = app.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::permissive())
// Always apply authentication middleware
.layer(middleware::from_fn_with_state(
(
self.app_state.jwt_service.clone(),
self.app_state.auth_repository.clone(),
),
auth_middleware,
)),
);
// Apply RBAC middleware conditionally
app = app.apply_rbac_if_enabled(&self.app_state.rbac_service);
// Apply specific RBAC middleware to protected routes
app = self.apply_route_specific_rbac_middleware(app);
app.with_state(self.app_state.clone())
}
/// Apply route-specific RBAC middleware conditionally
fn apply_route_specific_rbac_middleware(
&self,
mut router: Router<AppState>,
) -> Router<AppState> {
// Database access protection
if self
.app_state
.rbac_service
.is_feature_enabled("database_access")
{
router = router
.route(
"/api/database/:db_name",
get(get_database_info).layer(middleware::from_fn(
self.app_state
.rbac_service
.database_access_middleware("*".to_string(), "read".to_string())
.unwrap(),
)),
)
.route(
"/api/database/:db_name/query",
post(execute_database_query).layer(middleware::from_fn(
self.app_state
.rbac_service
.database_access_middleware("*".to_string(), "write".to_string())
.unwrap(),
)),
);
}
// File access protection
if self
.app_state
.rbac_service
.is_feature_enabled("file_access")
{
router = router
.route(
"/api/files/*path",
get(read_file).layer(middleware::from_fn(
self.app_state
.rbac_service
.file_access_middleware("*".to_string(), "read".to_string())
.unwrap(),
)),
)
.route(
"/api/files/*path",
post(write_file).layer(middleware::from_fn(
self.app_state
.rbac_service
.file_access_middleware("*".to_string(), "write".to_string())
.unwrap(),
)),
);
}
// Content access protection
if self
.app_state
.rbac_service
.is_feature_enabled("content_access")
{
router = router
.route(
"/api/content/:content_id",
get(get_content).layer(middleware::from_fn(
self.app_state
.rbac_service
.content_access_middleware("*".to_string(), "read".to_string())
.unwrap(),
)),
)
.route(
"/api/content/:content_id",
post(update_content).layer(middleware::from_fn(
self.app_state
.rbac_service
.content_access_middleware("*".to_string(), "write".to_string())
.unwrap(),
)),
);
}
// Admin routes with category protection
if self.app_state.rbac_service.is_feature_enabled("categories") {
router = router
.route(
"/api/admin/users",
get(list_users).layer(middleware::from_fn(
self.app_state
.rbac_service
.category_access_middleware(vec!["admin".to_string()])
.unwrap(),
)),
)
.route(
"/api/admin/users/:user_id",
get(get_user).layer(middleware::from_fn(
self.app_state
.rbac_service
.category_access_middleware(vec!["admin".to_string()])
.unwrap(),
)),
);
}
router
}
/// Start the server
pub async fn start(&self) -> Result<()> {
let app = self.build_router();
let addr = self.config.server_address();
// Print startup information
println!("🚀 Starting Rustelo server...");
println!("📍 Address: {}", addr);
println!("🌍 Environment: {:?}", self.config.server.environment);
if self.app_state.rbac_service.is_enabled() {
println!("🔐 RBAC System: Enabled");
let status = self.app_state.rbac_service.get_feature_status();
println!(
" └─ Features: {}",
serde_json::to_string_pretty(&status["features"])?
);
} else {
println!("🔒 RBAC System: Disabled (using basic role-based auth)");
}
// Start background tasks
self.start_background_tasks().await;
// Start the server
let listener = tokio::net::TcpListener::bind(&addr).await?;
println!("✅ Server running on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}
/// Start background tasks conditionally
async fn start_background_tasks(&self) {
// Start RBAC background tasks if enabled
self.app_state.rbac_service.start_background_tasks().await;
// Start general maintenance tasks
tokio::spawn(async {
let mut interval = interval(Duration::from_secs(3600)); // 1 hour
loop {
interval.tick().await;
println!("🧹 Running periodic maintenance tasks...");
// Add general cleanup tasks here
}
});
println!("🚀 Background tasks started");
}
}
// =============================================================================
// Route Handlers
// =============================================================================
/// Liveness probe: always answers `ok` with the service name and current UTC time.
async fn health_check() -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "status": "ok",
        "timestamp": chrono::Utc::now(),
        "service": "rustelo"
    });
    Ok(Json(body))
}
/// API health check that also reports database connectivity and RBAC status.
///
/// Runs `SELECT 1` against the authentication repository's pool. Any error is
/// reported as `"disconnected"` instead of failing the request.
async fn api_health_check(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let probe = sqlx::query("SELECT 1")
        .fetch_one(&state.auth_repository.pool)
        .await;
    let db_status = if probe.is_ok() {
        "connected"
    } else {
        "disconnected"
    };

    Ok(Json(json!({
        "status": "ok",
        "timestamp": chrono::Utc::now(),
        "service": "rustelo",
        "database": db_status,
        "rbac_enabled": state.rbac_service.is_enabled()
    })))
}
/// Report which application features are enabled, including the RBAC
/// service's own feature-status document.
async fn feature_status(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let flags = &state.feature_config;
    let report = json!({
        "auth": flags.auth.enabled,
        "rbac": state.rbac_service.get_feature_status(),
        "content": flags.content.enabled,
        "security": {
            "csrf": flags.security.csrf,
            "rate_limiting": flags.security.rate_limiting
        },
        "performance": {
            "caching": flags.performance.response_caching,
            "compression": flags.performance.compression
        }
    });
    Ok(Json(report))
}
// Authentication endpoints (simplified - integrate with full auth service)

/// Placeholder login endpoint.
///
/// NOTE(review): the `LoginCredentials` body is deserialized (so the extractor
/// rejects malformed JSON) but never checked. The handler always reports
/// success and issues no token. Wire it to the real auth service before use.
async fn auth_login(
    State(state): State<AppState>,
    Json(_credentials): Json<shared::auth::LoginCredentials>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Login successful",
        "rbac_enabled": state.rbac_service.is_enabled()
    })))
}
/// Placeholder registration endpoint.
///
/// NOTE(review): the `RegisterUserData` body is deserialized but no user is
/// created. The handler always reports success.
async fn auth_register(
    State(_state): State<AppState>,
    Json(_user_data): Json<shared::auth::RegisterUserData>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Registration successful"
    })))
}
/// Placeholder token-refresh endpoint.
///
/// NOTE(review): the `RefreshTokenRequest` body is deserialized but no token is
/// validated or issued. The handler always reports success.
async fn auth_refresh(
    State(_state): State<AppState>,
    Json(_request): Json<shared::auth::RefreshTokenRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Token refreshed"
    })))
}
/// Placeholder logout endpoint. It always reports success and revokes nothing.
async fn auth_logout(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Logged out successfully"
    })))
}
// User profile endpoints

/// Placeholder profile endpoint.
///
/// NOTE(review): returns a fixed sample profile, not the authenticated user's.
async fn get_user_profile(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "profile": {
            "username": "example_user",
            "email": "user@example.com",
            "roles": ["user"]
        }
    })))
}
/// Placeholder profile-update endpoint.
///
/// NOTE(review): the `UpdateUserData` body is deserialized but nothing is
/// persisted. The handler always reports success.
async fn update_user_profile(
    State(_state): State<AppState>,
    Json(_profile): Json<shared::auth::UpdateUserData>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "message": "Profile updated successfully"
    })))
}
// Content endpoints (protected by conditional RBAC)

/// Return sample content for `content_id`, labelled with the protection mode
/// that applies (RBAC when `content_access` is enabled, basic roles otherwise).
async fn get_content(
    axum::extract::Path(content_id): axum::extract::Path<String>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let rbac_guarded = state.rbac_service.is_feature_enabled("content_access");
    let protection = if rbac_guarded {
        "RBAC Protected"
    } else {
        "Basic Role Protected"
    };

    Ok(Json(json!({
        "content_id": content_id,
        "title": "Example Content",
        "body": "This is example content...",
        "protection": protection
    })))
}
/// Placeholder content-update endpoint.
///
/// NOTE(review): the JSON body is deserialized but not stored. The handler
/// always reports success for `content_id`.
async fn update_content(
    axum::extract::Path(content_id): axum::extract::Path<String>,
    State(_state): State<AppState>,
    Json(_content): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "content_id": content_id,
        "message": "Content updated successfully"
    })))
}
// Database endpoints (protected by conditional RBAC)

/// Report `db_name` as accessible, labelled with the protection mode that
/// applies (RBAC when `database_access` is enabled, basic roles otherwise).
async fn get_database_info(
    axum::extract::Path(db_name): axum::extract::Path<String>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let rbac_guarded = state.rbac_service.is_feature_enabled("database_access");
    let protection = if rbac_guarded {
        "RBAC Protected"
    } else {
        "Basic Role Protected"
    };

    Ok(Json(json!({
        "database": db_name,
        "status": "accessible",
        "protection": protection
    })))
}
/// Placeholder query endpoint.
///
/// NOTE(review): the JSON query body is deserialized but never executed. The
/// handler always reports success for `db_name`.
async fn execute_database_query(
    axum::extract::Path(db_name): axum::extract::Path<String>,
    State(_state): State<AppState>,
    Json(_query): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "database": db_name,
        "message": "Query executed successfully"
    })))
}
// File endpoints (protected by conditional RBAC)

/// Return placeholder content for `file_path`, labelled with the protection
/// mode that applies (RBAC when `file_access` is enabled, basic roles otherwise).
async fn read_file(
    axum::extract::Path(file_path): axum::extract::Path<String>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let rbac_guarded = state.rbac_service.is_feature_enabled("file_access");
    let protection = if rbac_guarded {
        "RBAC Protected"
    } else {
        "Basic Role Protected"
    };

    Ok(Json(json!({
        "file_path": file_path,
        "content": "File content here...",
        "protection": protection
    })))
}
/// Placeholder file-write endpoint.
///
/// NOTE(review): the JSON body is deserialized but nothing is written to disk.
/// The handler always reports success for `file_path`.
async fn write_file(
    axum::extract::Path(file_path): axum::extract::Path<String>,
    State(_state): State<AppState>,
    Json(_request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "file_path": file_path,
        "message": "File written successfully"
    })))
}
// Admin endpoints (protected by conditional RBAC categories)

/// List users (currently always empty), labelled with the protection mode that
/// applies (RBAC "admin" category when `categories` is enabled).
async fn list_users(State(state): State<AppState>) -> Result<Json<serde_json::Value>, StatusCode> {
    let category_guarded = state.rbac_service.is_feature_enabled("categories");
    let protection = if category_guarded {
        "RBAC Category Protected (admin)"
    } else {
        "Basic Admin Role Protected"
    };

    Ok(Json(json!({
        "users": [],
        "protection": protection
    })))
}
/// Placeholder admin lookup. It echoes `user_id` with an empty user object.
async fn get_user(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "user_id": user_id,
        "user": {},
        "message": "User retrieved successfully"
    })))
}
/// Placeholder admin update.
///
/// NOTE(review): the JSON body is deserialized but nothing is persisted. The
/// handler always reports success for `user_id`.
async fn update_user(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(_state): State<AppState>,
    Json(_user_data): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "user_id": user_id,
        "message": "User updated successfully"
    })))
}
/// Placeholder admin delete. Nothing is deleted; the handler always reports
/// success for `user_id`.
async fn delete_user(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    Ok(Json(json!({
        "success": true,
        "user_id": user_id,
        "message": "User deleted successfully"
    })))
}
/// CLI entry point for the RBAC-compatible server.
///
/// Installs the tracing subscriber, loads and validates the configuration, then
/// builds the server and runs it. Any setup or serve error is returned.
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    let config = Config::load().await?;
    config.validate()?;

    RBACCompatibleServer::new(config).await?.start().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for server-construction inputs.
    ///
    /// Building an `RBACCompatibleServer` needs a live database, so this only
    /// checks that a default `Config` can be constructed. Replace it with a real
    /// test once a test database is available.
    #[test]
    fn test_server_creation() {
        let _config = Config::default();
    }

    /// With no underlying RBAC service and default feature flags, the
    /// conditional middleware factories must all return `None`.
    #[test]
    fn test_conditional_middleware() {
        let feature_config = Arc::new(FeatureConfig::default());
        let rbac_service = ConditionalRBACService {
            rbac_service: None,
            rbac_repository: None,
            feature_config,
        };

        assert!(
            rbac_service
                .database_access_middleware("test".to_string(), "read".to_string())
                .is_none()
        );
        assert!(
            rbac_service
                .file_access_middleware("test".to_string(), "read".to_string())
                .is_none()
        );
        assert!(
            rbac_service
                .category_access_middleware(vec!["admin".to_string()])
                .is_none()
        );
    }
}