538 lines
17 KiB
Rust
538 lines
17 KiB
Rust
use anyhow::Result;
|
|
use axum::{
|
|
Router,
|
|
extract::State,
|
|
http::{HeaderMap, StatusCode},
|
|
middleware,
|
|
response::Json,
|
|
routing::{get, post},
|
|
};
|
|
use serde_json::json;
|
|
|
|
use std::sync::Arc;
|
|
use tokio::time::{Duration as TokioDuration, interval};
|
|
use tower::ServiceBuilder;
|
|
use tower_http::cors::CorsLayer;
|
|
use tower_http::trace::TraceLayer;
|
|
|
|
use crate::auth::{JwtService, RBACConfigLoader, RBACRepository, RBACService, auth_middleware};
|
|
use crate::database::{Database, DatabaseConfig, DatabasePool};
|
|
use crate::examples::rbac_integration::{
|
|
AppState, create_rbac_routes, initialize_rbac_system, setup_rbac_middleware,
|
|
};
|
|
use std::time::Duration as StdDuration;
|
|
|
|
/// Main server configuration with RBAC
|
|
pub struct RBACServer {
|
|
pub app_state: AppState,
|
|
pub host: String,
|
|
pub port: u16,
|
|
}
|
|
|
|
impl RBACServer {
|
|
/// Create a new RBAC-enabled server
|
|
pub async fn new(
|
|
database_url: &str,
|
|
rbac_config_path: &str,
|
|
jwt_secret: &str,
|
|
host: String,
|
|
port: u16,
|
|
) -> Result<Self> {
|
|
// Initialize database connection using new abstraction
|
|
let database_config = DatabaseConfig {
|
|
url: database_url.to_string(),
|
|
max_connections: 20,
|
|
min_connections: 1,
|
|
connect_timeout: StdDuration::from_secs(30),
|
|
idle_timeout: StdDuration::from_secs(600),
|
|
max_lifetime: StdDuration::from_secs(3600),
|
|
};
|
|
|
|
let database_pool = DatabasePool::new(&database_config).await?;
|
|
let database = Database::new(database_pool.clone());
|
|
|
|
// Initialize repositories using new database abstraction
|
|
let auth_repository = Arc::new(crate::database::auth::AuthRepository::new(
|
|
database.create_connection(),
|
|
));
|
|
let rbac_repository = Arc::new(RBACRepository::from_database_pool(&database_pool));
|
|
|
|
// Initialize JWT service
|
|
let jwt_service = Arc::new(
|
|
JwtService::new()
|
|
.map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?,
|
|
);
|
|
|
|
// Initialize RBAC service
|
|
let rbac_service = Arc::new(RBACService::new(rbac_repository.clone()));
|
|
|
|
// Load RBAC configuration
|
|
let config_loader = RBACConfigLoader::new(rbac_config_path);
|
|
if !config_loader.config_exists() {
|
|
println!("Creating default RBAC configuration...");
|
|
config_loader.create_default_config().await?;
|
|
}
|
|
|
|
// Load and save config to database
|
|
let rbac_config = config_loader.load_from_file().await?;
|
|
rbac_service
|
|
.save_rbac_config("default", &rbac_config, Some("Server initialization"))
|
|
.await?;
|
|
|
|
println!(
|
|
"RBAC system initialized with {} rules",
|
|
rbac_config.rules.len()
|
|
);
|
|
|
|
let app_state = AppState {
|
|
rbac_service,
|
|
rbac_repository,
|
|
auth_repository,
|
|
jwt_service,
|
|
};
|
|
|
|
Ok(Self {
|
|
app_state,
|
|
host,
|
|
port,
|
|
})
|
|
}
|
|
|
|
/// Build the application router with RBAC middleware
|
|
pub fn build_router(&self) -> Router {
|
|
let app = Router::new()
|
|
// Health check endpoint
|
|
.route("/health", get(health_check))
|
|
.route("/api/health", get(api_health_check))
|
|
// Authentication routes (no RBAC required)
|
|
.route("/api/auth/login", post(auth_login))
|
|
.route("/api/auth/register", post(auth_register))
|
|
.route("/api/auth/refresh", post(auth_refresh))
|
|
.route("/api/auth/logout", post(auth_logout))
|
|
// User profile routes (basic auth required)
|
|
.route("/api/user/profile", get(get_user_profile))
|
|
.route("/api/user/profile", post(update_user_profile))
|
|
// RBAC management routes (admin only)
|
|
.route("/api/rbac/config", get(get_rbac_config))
|
|
.route("/api/rbac/config", post(update_rbac_config))
|
|
.route("/api/rbac/categories", get(list_categories))
|
|
.route("/api/rbac/categories", post(create_category))
|
|
.route("/api/rbac/tags", get(list_tags))
|
|
.route("/api/rbac/tags", post(create_tag))
|
|
.route(
|
|
"/api/rbac/users/:user_id/categories",
|
|
get(get_user_categories),
|
|
)
|
|
.route(
|
|
"/api/rbac/users/:user_id/categories",
|
|
post(assign_user_category),
|
|
)
|
|
.route("/api/rbac/users/:user_id/tags", get(get_user_tags))
|
|
.route("/api/rbac/users/:user_id/tags", post(assign_user_tag))
|
|
.route("/api/rbac/audit/:user_id", get(get_access_audit))
|
|
// Merge with RBAC-protected routes
|
|
.merge(create_rbac_routes(self.app_state.clone()))
|
|
// Apply global middleware
|
|
.layer(
|
|
ServiceBuilder::new()
|
|
.layer(TraceLayer::new_for_http())
|
|
.layer(CorsLayer::permissive())
|
|
.layer(middleware::from_fn_with_state(
|
|
(
|
|
self.app_state.jwt_service.clone(),
|
|
self.app_state.auth_repository.clone(),
|
|
),
|
|
auth_middleware,
|
|
))
|
|
.layer(middleware::from_fn_with_state(
|
|
self.app_state.rbac_service.clone(),
|
|
crate::auth::rbac_middleware::rbac_middleware,
|
|
)),
|
|
)
|
|
.with_state(self.app_state.clone());
|
|
|
|
// Apply RBAC middleware to specific routes
|
|
setup_rbac_middleware(app)
|
|
}
|
|
|
|
/// Start the server
|
|
pub async fn start(&self) -> Result<()> {
|
|
let app = self.build_router();
|
|
let addr = format!("{}:{}", self.host, self.port);
|
|
|
|
println!("Starting RBAC-enabled server on {}", addr);
|
|
|
|
// Start background tasks
|
|
let cleanup_state = self.app_state.clone();
|
|
tokio::spawn(async move {
|
|
let mut interval = interval(Duration::from_secs(300)); // 5 minutes
|
|
loop {
|
|
interval.tick().await;
|
|
if let Err(e) = cleanup_state.rbac_service.cleanup_expired_cache().await {
|
|
eprintln!("Error cleaning up expired cache: {}", e);
|
|
}
|
|
}
|
|
});
|
|
|
|
// Start the server
|
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
|
axum::serve(listener, app).await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Health check endpoint
|
|
async fn health_check() -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"status": "ok",
|
|
"timestamp": chrono::Utc::now(),
|
|
"service": "rustelo-rbac"
|
|
})))
|
|
}
|
|
|
|
/// API health check with more details
|
|
async fn api_health_check(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
// Check database connectivity
|
|
let db_status = match state.rbac_repository.get_rbac_config("default").await {
|
|
Ok(_) => "connected",
|
|
Err(_) => "disconnected",
|
|
};
|
|
|
|
Ok(Json(json!({
|
|
"status": "ok",
|
|
"timestamp": chrono::Utc::now(),
|
|
"service": "rustelo-rbac",
|
|
"database": db_status,
|
|
"rbac_enabled": true
|
|
})))
|
|
}
|
|
|
|
/// Authentication login endpoint
|
|
async fn auth_login(
|
|
State(state): State<AppState>,
|
|
Json(credentials): Json<shared::auth::LoginCredentials>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
// This is a simplified example - in a real implementation,
|
|
// you'd use the full AuthService
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Login endpoint - implement with AuthService",
|
|
"email": credentials.email
|
|
})))
|
|
}
|
|
|
|
/// Authentication register endpoint
|
|
async fn auth_register(
|
|
State(state): State<AppState>,
|
|
Json(user_data): Json<shared::auth::RegisterUserData>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Register endpoint - implement with AuthService",
|
|
"email": user_data.email
|
|
})))
|
|
}
|
|
|
|
/// Token refresh endpoint
|
|
async fn auth_refresh(
|
|
State(state): State<AppState>,
|
|
Json(request): Json<shared::auth::RefreshTokenRequest>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Refresh endpoint - implement with AuthService"
|
|
})))
|
|
}
|
|
|
|
/// Logout endpoint
|
|
async fn auth_logout(State(state): State<AppState>) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Logged out successfully"
|
|
})))
|
|
}
|
|
|
|
/// Get user profile
|
|
async fn get_user_profile(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "User profile endpoint - implement with AuthService"
|
|
})))
|
|
}
|
|
|
|
/// Update user profile
|
|
async fn update_user_profile(
|
|
State(state): State<AppState>,
|
|
Json(profile): Json<shared::auth::UpdateUserData>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Profile update endpoint - implement with AuthService"
|
|
})))
|
|
}
|
|
|
|
/// Get RBAC configuration
|
|
async fn get_rbac_config(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
match state.rbac_service.get_rbac_config("default").await {
|
|
Ok(config) => Ok(Json(json!({
|
|
"success": true,
|
|
"config": config
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Update RBAC configuration
|
|
async fn update_rbac_config(
|
|
State(state): State<AppState>,
|
|
Json(config): Json<shared::auth::RBACConfig>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
match state
|
|
.rbac_service
|
|
.save_rbac_config("default", &config, Some("Updated via API"))
|
|
.await
|
|
{
|
|
Ok(_) => Ok(Json(json!({
|
|
"success": true,
|
|
"message": "RBAC configuration updated"
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// List all categories
|
|
async fn list_categories(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"categories": ["admin", "editor", "viewer", "finance", "hr", "it"],
|
|
"message": "Categories retrieved successfully"
|
|
})))
|
|
}
|
|
|
|
/// Create a new category
|
|
async fn create_category(
|
|
State(state): State<AppState>,
|
|
Json(category): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
let name = category
|
|
.get("name")
|
|
.and_then(|n| n.as_str())
|
|
.unwrap_or("unnamed");
|
|
let description = category.get("description").and_then(|d| d.as_str());
|
|
|
|
match state
|
|
.rbac_repository
|
|
.create_category(name, description, None)
|
|
.await
|
|
{
|
|
Ok(created_category) => Ok(Json(json!({
|
|
"success": true,
|
|
"category": {
|
|
"id": created_category.id,
|
|
"name": created_category.name,
|
|
"description": created_category.description
|
|
},
|
|
"message": "Category created successfully"
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// List all tags
|
|
async fn list_tags(State(state): State<AppState>) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"tags": ["sensitive", "public", "internal", "confidential", "restricted", "temporary"],
|
|
"message": "Tags retrieved successfully"
|
|
})))
|
|
}
|
|
|
|
/// Create a new tag
|
|
async fn create_tag(
|
|
State(state): State<AppState>,
|
|
Json(tag): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
let name = tag
|
|
.get("name")
|
|
.and_then(|n| n.as_str())
|
|
.unwrap_or("unnamed");
|
|
let description = tag.get("description").and_then(|d| d.as_str());
|
|
let color = tag.get("color").and_then(|c| c.as_str());
|
|
|
|
match state
|
|
.rbac_repository
|
|
.create_tag(name, description, color)
|
|
.await
|
|
{
|
|
Ok(created_tag) => Ok(Json(json!({
|
|
"success": true,
|
|
"tag": {
|
|
"id": created_tag.id,
|
|
"name": created_tag.name,
|
|
"description": created_tag.description,
|
|
"color": created_tag.color
|
|
},
|
|
"message": "Tag created successfully"
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Get user categories
|
|
async fn get_user_categories(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
match state.rbac_repository.get_user_categories(user_id).await {
|
|
Ok(categories) => Ok(Json(json!({
|
|
"success": true,
|
|
"user_id": user_id,
|
|
"categories": categories
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Assign category to user
|
|
async fn assign_user_category(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
Json(request): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
let category_name = request
|
|
.get("category")
|
|
.and_then(|c| c.as_str())
|
|
.unwrap_or("");
|
|
|
|
match state
|
|
.rbac_service
|
|
.assign_category_to_user(user_id, category_name, None, None)
|
|
.await
|
|
{
|
|
Ok(_) => Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Category assigned successfully"
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Get user tags
|
|
async fn get_user_tags(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
match state.rbac_repository.get_user_tags(user_id).await {
|
|
Ok(tags) => Ok(Json(json!({
|
|
"success": true,
|
|
"user_id": user_id,
|
|
"tags": tags
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Assign tag to user
|
|
async fn assign_user_tag(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
Json(request): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
let tag_name = request.get("tag").and_then(|t| t.as_str()).unwrap_or("");
|
|
|
|
match state
|
|
.rbac_service
|
|
.assign_tag_to_user(user_id, tag_name, None, None)
|
|
.await
|
|
{
|
|
Ok(_) => Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Tag assigned successfully"
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// Get access audit for user
|
|
async fn get_access_audit(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
match state
|
|
.rbac_service
|
|
.get_user_access_history(user_id, 100)
|
|
.await
|
|
{
|
|
Ok(history) => Ok(Json(json!({
|
|
"success": true,
|
|
"user_id": user_id,
|
|
"audit_log": history
|
|
}))),
|
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
|
}
|
|
}
|
|
|
|
/// CLI entry point for RBAC server
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Initialize tracing
|
|
tracing_subscriber::fmt::init();
|
|
|
|
// Load configuration from environment
|
|
let database_url = std::env::var("DATABASE_URL")
|
|
.unwrap_or_else(|_| "postgres://dev:dev@localhost:5432/rustelo_dev".to_string());
|
|
|
|
let rbac_config_path =
|
|
std::env::var("RBAC_CONFIG_PATH").unwrap_or_else(|_| "config/rbac.toml".to_string());
|
|
|
|
let jwt_secret = std::env::var("JWT_SECRET")
|
|
.unwrap_or_else(|_| "your-super-secret-jwt-key-change-this-in-production".to_string());
|
|
|
|
let host = std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
|
|
let port = std::env::var("SERVER_PORT")
|
|
.unwrap_or_else(|_| "3030".to_string())
|
|
.parse::<u16>()
|
|
.unwrap_or(3030);
|
|
|
|
println!("Initializing RBAC server...");
|
|
println!("Database URL: {}", database_url);
|
|
println!("RBAC Config: {}", rbac_config_path);
|
|
println!("Server: {}:{}", host, port);
|
|
|
|
// Create and start server
|
|
let server = RBACServer::new(&database_url, &rbac_config_path, &jwt_secret, host, port).await?;
|
|
server.start().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The liveness endpoint must always succeed with the fixed service identity.
    #[tokio::test]
    async fn test_health_check() {
        let Json(body) = health_check().await.expect("health_check never returns Err");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["service"], "rustelo-rbac");
        assert!(body.get("timestamp").is_some());
    }

    /// End-to-end construction against a real database.
    ///
    /// Ignored by default because it needs a reachable PostgreSQL instance.
    /// Run it with `DATABASE_URL=... cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires a reachable PostgreSQL database; set DATABASE_URL and run with --ignored"]
    async fn test_server_creation() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must point at a test database");
        // Use a throwaway config path so the test never writes config/rbac.toml.
        let config_file = std::env::temp_dir().join("rustelo_rbac_server_test.toml");
        let config_path = config_file
            .to_str()
            .expect("temp dir path must be valid UTF-8");

        let server = RBACServer::new(
            &database_url,
            config_path,
            "test-jwt-secret",
            "127.0.0.1".to_string(),
            0,
        )
        .await;

        assert!(server.is_ok(), "server creation failed: {:?}", server.err());
    }
}
|