Rustelo/server/src/rbac_main.rs
2025-07-07 23:05:19 +01:00

538 lines
17 KiB
Rust

use anyhow::Result;
use axum::{
Router,
extract::State,
http::{HeaderMap, StatusCode},
middleware,
response::Json,
routing::{get, post},
};
use serde_json::json;
use std::sync::Arc;
use tokio::time::{Duration as TokioDuration, interval};
use tower::ServiceBuilder;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use crate::auth::{JwtService, RBACConfigLoader, RBACRepository, RBACService, auth_middleware};
use crate::database::{Database, DatabaseConfig, DatabasePool};
use crate::examples::rbac_integration::{
AppState, create_rbac_routes, initialize_rbac_system, setup_rbac_middleware,
};
use std::time::Duration as StdDuration;
/// Main server configuration with RBAC.
///
/// Bundles the shared application state with the address pieces that
/// `start` joins into the listen address.
pub struct RBACServer {
    /// Shared services and repositories handed to the router and its handlers.
    pub app_state: AppState,
    /// Host name or IP address part of the listen address.
    pub host: String,
    /// TCP port part of the listen address.
    pub port: u16,
}
impl RBACServer {
    /// Create a new RBAC-enabled server.
    ///
    /// Opens a database pool for `database_url`, builds the auth and RBAC
    /// repositories and services on top of it, makes sure an RBAC config file
    /// exists at `rbac_config_path` (writing a default one when it is
    /// missing), then loads that file and saves it to the database under the
    /// name `"default"`.
    ///
    /// # Errors
    ///
    /// Returns an error when the database pool cannot be created, when the
    /// JWT service fails to initialize, or when creating, loading or saving
    /// the RBAC configuration fails.
    pub async fn new(
        database_url: &str,
        rbac_config_path: &str,
        jwt_secret: &str,
        host: String,
        port: u16,
    ) -> Result<Self> {
        // NOTE(review): `jwt_secret` is accepted but never read in this
        // function; `JwtService::new()` below takes no arguments. Confirm
        // where the JWT service obtains its secret.

        // Initialize database connection using new abstraction
        let database_config = DatabaseConfig {
            url: database_url.to_string(),
            max_connections: 20,
            min_connections: 1,
            connect_timeout: StdDuration::from_secs(30),
            idle_timeout: StdDuration::from_secs(600),
            max_lifetime: StdDuration::from_secs(3600),
        };
        let database_pool = DatabasePool::new(&database_config).await?;
        let database = Database::new(database_pool.clone());

        // Initialize repositories using new database abstraction
        let auth_repository = Arc::new(crate::database::auth::AuthRepository::new(
            database.create_connection(),
        ));
        let rbac_repository = Arc::new(RBACRepository::from_database_pool(&database_pool));

        // Initialize JWT service
        let jwt_service = Arc::new(
            JwtService::new()
                .map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?,
        );

        // Initialize RBAC service
        let rbac_service = Arc::new(RBACService::new(rbac_repository.clone()));

        // Load RBAC configuration, creating the default file on first run
        let config_loader = RBACConfigLoader::new(rbac_config_path);
        if !config_loader.config_exists() {
            println!("Creating default RBAC configuration...");
            config_loader.create_default_config().await?;
        }

        // Load and save config to database
        let rbac_config = config_loader.load_from_file().await?;
        rbac_service
            .save_rbac_config("default", &rbac_config, Some("Server initialization"))
            .await?;
        println!(
            "RBAC system initialized with {} rules",
            rbac_config.rules.len()
        );

        let app_state = AppState {
            rbac_service,
            rbac_repository,
            auth_repository,
            jwt_service,
        };
        Ok(Self {
            app_state,
            host,
            port,
        })
    }

    /// Build the application router with RBAC middleware.
    ///
    /// Registers the health, authentication, profile and RBAC management
    /// routes, merges in the routes from `create_rbac_routes`, wraps the
    /// result in the tracing, CORS, auth and RBAC layers, and finally passes
    /// the router through `setup_rbac_middleware`.
    pub fn build_router(&self) -> Router {
        let app = Router::new()
            // Health check endpoint
            .route("/health", get(health_check))
            .route("/api/health", get(api_health_check))
            // Authentication routes (no RBAC required)
            // NOTE(review): the layers added below wrap every route registered
            // above them, including these and the health checks; whether
            // unauthenticated requests get through depends on the middleware
            // implementations — confirm.
            .route("/api/auth/login", post(auth_login))
            .route("/api/auth/register", post(auth_register))
            .route("/api/auth/refresh", post(auth_refresh))
            .route("/api/auth/logout", post(auth_logout))
            // User profile routes (basic auth required)
            .route("/api/user/profile", get(get_user_profile))
            .route("/api/user/profile", post(update_user_profile))
            // RBAC management routes (admin only)
            .route("/api/rbac/config", get(get_rbac_config))
            .route("/api/rbac/config", post(update_rbac_config))
            .route("/api/rbac/categories", get(list_categories))
            .route("/api/rbac/categories", post(create_category))
            .route("/api/rbac/tags", get(list_tags))
            .route("/api/rbac/tags", post(create_tag))
            .route(
                "/api/rbac/users/:user_id/categories",
                get(get_user_categories),
            )
            .route(
                "/api/rbac/users/:user_id/categories",
                post(assign_user_category),
            )
            .route("/api/rbac/users/:user_id/tags", get(get_user_tags))
            .route("/api/rbac/users/:user_id/tags", post(assign_user_tag))
            .route("/api/rbac/audit/:user_id", get(get_access_audit))
            // Merge with RBAC-protected routes
            .merge(create_rbac_routes(self.app_state.clone()))
            // Apply global middleware
            .layer(
                ServiceBuilder::new()
                    .layer(TraceLayer::new_for_http())
                    .layer(CorsLayer::permissive())
                    .layer(middleware::from_fn_with_state(
                        (
                            self.app_state.jwt_service.clone(),
                            self.app_state.auth_repository.clone(),
                        ),
                        auth_middleware,
                    ))
                    .layer(middleware::from_fn_with_state(
                        self.app_state.rbac_service.clone(),
                        crate::auth::rbac_middleware::rbac_middleware,
                    )),
            )
            .with_state(self.app_state.clone());

        // Apply RBAC middleware to specific routes
        setup_rbac_middleware(app)
    }

    /// Start the server.
    ///
    /// Spawns a background task that calls `cleanup_expired_cache` on the
    /// RBAC service on a 300-second interval, binds a TCP listener on
    /// `host:port`, and serves the router on it.
    ///
    /// # Errors
    ///
    /// Returns an error when the listener cannot be bound or when serving
    /// fails.
    pub async fn start(&self) -> Result<()> {
        let app = self.build_router();
        let addr = format!("{}:{}", self.host, self.port);
        println!("Starting RBAC-enabled server on {}", addr);

        // Start background tasks
        let cleanup_state = self.app_state.clone();
        tokio::spawn(async move {
            // The file imports tokio's duration type as `TokioDuration`; a bare
            // `Duration` is not in scope.
            let mut ticker = interval(TokioDuration::from_secs(300)); // 5 minutes
            loop {
                ticker.tick().await;
                // A failed cleanup is logged and retried on the next tick.
                if let Err(e) = cleanup_state.rbac_service.cleanup_expired_cache().await {
                    eprintln!("Error cleaning up expired cache: {}", e);
                }
            }
        });

        // Start the server
        let listener = tokio::net::TcpListener::bind(&addr).await?;
        axum::serve(listener, app).await?;
        Ok(())
    }
}
/// Liveness probe.
///
/// Always succeeds with a fixed `"ok"` status, the service name and the
/// current UTC time; it touches no shared state.
async fn health_check() -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "status": "ok",
        "timestamp": chrono::Utc::now(),
        "service": "rustelo-rbac"
    });
    Ok(Json(body))
}
/// Detailed health check.
///
/// Probes the database by fetching the `"default"` RBAC config through the
/// repository and reports the outcome in the `"database"` field. The overall
/// `"status"` stays `"ok"` and the handler still returns success when the
/// probe fails.
async fn api_health_check(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Only success or failure of the lookup matters; the config is discarded.
    let probe = state.rbac_repository.get_rbac_config("default").await;
    let db_status = if probe.is_ok() {
        "connected"
    } else {
        "disconnected"
    };

    let body = json!({
        "status": "ok",
        "timestamp": chrono::Utc::now(),
        "service": "rustelo-rbac",
        "database": db_status,
        "rbac_enabled": true
    });
    Ok(Json(body))
}
/// Login endpoint (placeholder).
///
/// Deserializes the credentials and echoes the e-mail back with a success
/// flag. No credential check is performed here; a real implementation is
/// expected to go through the full `AuthService`.
async fn auth_login(
    State(_state): State<AppState>,
    Json(credentials): Json<shared::auth::LoginCredentials>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "Login endpoint - implement with AuthService",
        "email": credentials.email
    });
    Ok(Json(body))
}
/// Registration endpoint (placeholder).
///
/// Deserializes the registration payload and echoes the e-mail back with a
/// success flag. Nothing is persisted by this handler.
async fn auth_register(
    State(_state): State<AppState>,
    Json(user_data): Json<shared::auth::RegisterUserData>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "Register endpoint - implement with AuthService",
        "email": user_data.email
    });
    Ok(Json(body))
}
/// Token refresh endpoint (placeholder).
///
/// The request body must deserialize as a refresh-token request, but its
/// contents are not used; the handler answers with a fixed success payload.
async fn auth_refresh(
    State(_state): State<AppState>,
    Json(_request): Json<shared::auth::RefreshTokenRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "Refresh endpoint - implement with AuthService"
    });
    Ok(Json(body))
}
/// Logout endpoint (placeholder).
///
/// Answers with a fixed success payload; this handler itself does not
/// invalidate any token or session.
async fn auth_logout(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "Logged out successfully"
    });
    Ok(Json(body))
}
/// User profile lookup (placeholder).
///
/// Returns a fixed success payload; no profile data is read.
async fn get_user_profile(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "User profile endpoint - implement with AuthService"
    });
    Ok(Json(body))
}
/// User profile update (placeholder).
///
/// The body must deserialize as an update payload, but it is not applied
/// anywhere; the handler answers with a fixed success payload.
async fn update_user_profile(
    State(_state): State<AppState>,
    Json(_profile): Json<shared::auth::UpdateUserData>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "message": "Profile update endpoint - implement with AuthService"
    });
    Ok(Json(body))
}
/// Return the RBAC configuration stored under the name `"default"`.
///
/// Any service error is reported as `500 Internal Server Error`; the
/// underlying error is discarded.
async fn get_rbac_config(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let config = state
        .rbac_service
        .get_rbac_config("default")
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "config": config
    })))
}
/// Replace the RBAC configuration stored under the name `"default"`.
///
/// The request body is saved through the RBAC service with the description
/// `"Updated via API"`. Any service error is reported as
/// `500 Internal Server Error`.
async fn update_rbac_config(
    State(state): State<AppState>,
    Json(config): Json<shared::auth::RBACConfig>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    state
        .rbac_service
        .save_rbac_config("default", &config, Some("Updated via API"))
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "message": "RBAC configuration updated"
    })))
}
/// List categories.
///
/// Returns a hard-coded list of category names; the repository is not
/// queried by this handler.
async fn list_categories(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "categories": ["admin", "editor", "viewer", "finance", "hr", "it"],
        "message": "Categories retrieved successfully"
    });
    Ok(Json(body))
}
/// Create a new category from a JSON body.
///
/// Reads the string fields `name` and `description` from the body. A missing
/// or non-string `name` falls back to `"unnamed"`; `description` is optional.
/// The category is created through the repository with `None` as the third
/// argument, and its id, name and description are echoed back. Repository
/// errors are reported as `500 Internal Server Error`.
async fn create_category(
    State(state): State<AppState>,
    Json(category): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let name = match category.get("name").and_then(serde_json::Value::as_str) {
        Some(given) => given,
        None => "unnamed",
    };
    let description = category
        .get("description")
        .and_then(serde_json::Value::as_str);

    let created_category = state
        .rbac_repository
        .create_category(name, description, None)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "category": {
            "id": created_category.id,
            "name": created_category.name,
            "description": created_category.description
        },
        "message": "Category created successfully"
    })))
}
/// List tags.
///
/// Returns a hard-coded list of tag names; the repository is not queried by
/// this handler.
async fn list_tags(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let body = json!({
        "success": true,
        "tags": ["sensitive", "public", "internal", "confidential", "restricted", "temporary"],
        "message": "Tags retrieved successfully"
    });
    Ok(Json(body))
}
/// Create a new tag from a JSON body.
///
/// Reads the string fields `name`, `description` and `color` from the body.
/// A missing or non-string `name` falls back to `"unnamed"`; the other two
/// are optional. The created tag's id, name, description and color are echoed
/// back. Repository errors are reported as `500 Internal Server Error`.
async fn create_tag(
    State(state): State<AppState>,
    Json(tag): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let name = match tag.get("name").and_then(serde_json::Value::as_str) {
        Some(given) => given,
        None => "unnamed",
    };
    let description = tag.get("description").and_then(serde_json::Value::as_str);
    let color = tag.get("color").and_then(serde_json::Value::as_str);

    let created_tag = state
        .rbac_repository
        .create_tag(name, description, color)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "tag": {
            "id": created_tag.id,
            "name": created_tag.name,
            "description": created_tag.description,
            "color": created_tag.color
        },
        "message": "Tag created successfully"
    })))
}
/// Return the categories the repository reports for the user in the path.
///
/// `user_id` is parsed from the URL as a UUID. Repository errors are reported
/// as `500 Internal Server Error`.
async fn get_user_categories(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let categories = state
        .rbac_repository
        .get_user_categories(user_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "user_id": user_id,
        "categories": categories
    })))
}
/// Assign a category to the user in the path.
///
/// Expects a JSON body with a string field `category` naming the category.
/// The assignment is made through the RBAC service with `None` for both
/// trailing optional arguments.
///
/// # Errors
///
/// * `400 Bad Request` when `category` is missing or is not a string.
/// * `500 Internal Server Error` when the service call fails.
async fn assign_user_category(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(state): State<AppState>,
    Json(request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Reject a malformed body here instead of passing an empty placeholder
    // name to the service.
    let category_name = request
        .get("category")
        .and_then(|c| c.as_str())
        .ok_or(StatusCode::BAD_REQUEST)?;

    match state
        .rbac_service
        .assign_category_to_user(user_id, category_name, None, None)
        .await
    {
        Ok(_) => Ok(Json(json!({
            "success": true,
            "message": "Category assigned successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Return the tags the repository reports for the user in the path.
///
/// `user_id` is parsed from the URL as a UUID. Repository errors are reported
/// as `500 Internal Server Error`.
async fn get_user_tags(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let tags = state
        .rbac_repository
        .get_user_tags(user_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "user_id": user_id,
        "tags": tags
    })))
}
/// Assign a tag to the user in the path.
///
/// Expects a JSON body with a string field `tag` naming the tag. The
/// assignment is made through the RBAC service with `None` for both trailing
/// optional arguments.
///
/// # Errors
///
/// * `400 Bad Request` when `tag` is missing or is not a string.
/// * `500 Internal Server Error` when the service call fails.
async fn assign_user_tag(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(state): State<AppState>,
    Json(request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Reject a malformed body here instead of passing an empty placeholder
    // name to the service.
    let tag_name = request
        .get("tag")
        .and_then(|t| t.as_str())
        .ok_or(StatusCode::BAD_REQUEST)?;

    match state
        .rbac_service
        .assign_tag_to_user(user_id, tag_name, None, None)
        .await
    {
        Ok(_) => Ok(Json(json!({
            "success": true,
            "message": "Tag assigned successfully"
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Return the access history for the user in the path.
///
/// Asks the RBAC service for the user's access history, passing `100` as the
/// second argument, and returns it under `"audit_log"`. Service errors are
/// reported as `500 Internal Server Error`.
async fn get_access_audit(
    axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let history = state
        .rbac_service
        .get_user_access_history(user_id, 100)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(json!({
        "success": true,
        "user_id": user_id,
        "audit_log": history
    })))
}
/// CLI entry point for RBAC server.
///
/// Reads its settings from the environment, falling back to development
/// defaults when a variable is unset:
///
/// * `DATABASE_URL` — database connection string
/// * `RBAC_CONFIG_PATH` — path of the RBAC config file (`config/rbac.toml`)
/// * `JWT_SECRET` — JWT signing secret
/// * `SERVER_HOST` / `SERVER_PORT` — listen address (`127.0.0.1:3030`)
///
/// It then builds an [`RBACServer`] and runs it until the server stops.
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing
    tracing_subscriber::fmt::init();

    // Load configuration from environment
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://dev:dev@localhost:5432/rustelo_dev".to_string());
    let rbac_config_path =
        std::env::var("RBAC_CONFIG_PATH").unwrap_or_else(|_| "config/rbac.toml".to_string());
    // NOTE(review): the fallback below is a well-known placeholder value; a
    // deployment should always set JWT_SECRET explicitly.
    let jwt_secret = std::env::var("JWT_SECRET")
        .unwrap_or_else(|_| "your-super-secret-jwt-key-change-this-in-production".to_string());
    let host = std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
    let port = match std::env::var("SERVER_PORT") {
        // A set-but-invalid value still falls back to 3030, but no longer
        // silently.
        Ok(raw) => raw.parse::<u16>().unwrap_or_else(|err| {
            eprintln!(
                "Invalid SERVER_PORT {:?} ({}); falling back to 3030",
                raw, err
            );
            3030
        }),
        Err(_) => 3030,
    };

    println!("Initializing RBAC server...");
    // Never print the user name / password embedded in the connection string.
    println!("Database URL: {}", redact_database_url(&database_url));
    println!("RBAC Config: {}", rbac_config_path);
    println!("Server: {}:{}", host, port);

    // Create and start server
    let server = RBACServer::new(&database_url, &rbac_config_path, &jwt_secret, host, port).await?;
    server.start().await?;
    Ok(())
}

/// Return `url` with any `user[:password]@` part replaced by `***@`.
///
/// Only the authority section (between `://` and the first `/`, `?` or `#`)
/// is inspected. A string without `://` or without `@` in the authority is
/// returned unchanged.
fn redact_database_url(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(scheme_end) => scheme_end + "://".len(),
        None => return url.to_string(),
    };
    let rest = &url[authority_start..];
    let authority_len = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());

    // The last '@' in the authority separates the user info from the host.
    match rest[..authority_len].rfind('@') {
        Some(at) => format!("{}***@{}", &url[..authority_start], &rest[at + 1..]),
        None => url.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Method, Request, StatusCode};
    use tower::ServiceExt;

    /// The liveness handler reports `"status": "ok"`.
    #[tokio::test]
    async fn test_health_check() {
        let Json(payload) = health_check().await.unwrap();
        let status = payload.get("status").unwrap().as_str().unwrap();
        assert_eq!(status, "ok");
    }

    /// Placeholder: constructing a server needs a test database, so for now
    /// this only checks that the module compiles.
    #[tokio::test]
    async fn test_server_creation() {
        // This would require a test database setup
        // For now, just test that the structure compiles
        assert!(true);
    }
}