652 lines
21 KiB
Rust
652 lines
21 KiB
Rust
use anyhow::Result;
|
|
use axum::{
|
|
Router,
|
|
extract::State,
|
|
http::{HeaderMap, StatusCode},
|
|
middleware,
|
|
response::Json,
|
|
routing::{get, post},
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use std::sync::Arc;
|
|
use tokio::time::{Duration as TokioDuration, interval};
|
|
use tower::ServiceBuilder;
|
|
use tower_http::cors::CorsLayer;
|
|
use tower_http::trace::TraceLayer;
|
|
|
|
use crate::auth::{ConditionalRBACService, JwtService, auth_middleware};
|
|
use crate::config::{Config, features::FeatureConfig};
|
|
use crate::database::{Database, DatabaseConfig, DatabasePool};
|
|
use std::time::Duration as StdDuration;
|
|
|
|
/// Main application state with optional RBAC support
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
pub auth_repository: Arc<crate::database::auth::AuthRepository>,
|
|
pub jwt_service: Arc<JwtService>,
|
|
pub rbac_service: Arc<ConditionalRBACService>,
|
|
pub config: Arc<Config>,
|
|
pub feature_config: Arc<FeatureConfig>,
|
|
}
|
|
|
|
/// RBAC-compatible server that can run with or without RBAC features
|
|
pub struct RBACCompatibleServer {
|
|
pub app_state: AppState,
|
|
pub config: Arc<Config>,
|
|
}
|
|
|
|
impl RBACCompatibleServer {
|
|
/// Create a new server with optional RBAC features
|
|
pub async fn new(config: Config) -> Result<Self> {
|
|
let config = Arc::new(config);
|
|
|
|
// Initialize database connection using new abstraction
|
|
let database_config = DatabaseConfig {
|
|
url: config.database.url.clone(),
|
|
max_connections: config.database.max_connections,
|
|
min_connections: config.database.min_connections,
|
|
connect_timeout: StdDuration::from_secs(config.database.connect_timeout),
|
|
idle_timeout: StdDuration::from_secs(config.database.idle_timeout),
|
|
max_lifetime: StdDuration::from_secs(config.database.max_lifetime),
|
|
};
|
|
|
|
let database_pool = DatabasePool::new(&database_config).await?;
|
|
let database = Database::new(database_pool.clone());
|
|
|
|
// Initialize feature configuration
|
|
let mut feature_config = FeatureConfig::from_env();
|
|
|
|
// Override with config file settings
|
|
if config.features.rbac {
|
|
feature_config.enable_rbac();
|
|
}
|
|
if config.features.rbac_database_access {
|
|
feature_config.rbac.database_access = true;
|
|
}
|
|
if config.features.rbac_file_access {
|
|
feature_config.rbac.file_access = true;
|
|
}
|
|
if config.features.rbac_content_access {
|
|
feature_config.rbac.content_access = true;
|
|
}
|
|
if config.features.rbac_categories {
|
|
feature_config.rbac.categories = true;
|
|
}
|
|
if config.features.rbac_tags {
|
|
feature_config.rbac.tags = true;
|
|
}
|
|
if config.features.rbac_caching {
|
|
feature_config.rbac.caching = true;
|
|
}
|
|
if config.features.rbac_audit_logging {
|
|
feature_config.rbac.audit_logging = true;
|
|
}
|
|
|
|
let feature_config = Arc::new(feature_config);
|
|
|
|
// Initialize core authentication services
|
|
let auth_repository = Arc::new(crate::database::auth::AuthRepository::new(
|
|
database.create_connection(),
|
|
));
|
|
|
|
let jwt_service = Arc::new(
|
|
JwtService::new()
|
|
.map_err(|e| anyhow::anyhow!("Failed to create JWT service: {}", e))?,
|
|
);
|
|
|
|
// Initialize conditional RBAC service
|
|
let rbac_config_path = if feature_config.is_rbac_feature_enabled("toml_config") {
|
|
Some("config/rbac.toml")
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let rbac_service = Arc::new(
|
|
ConditionalRBACService::new(&database_pool, feature_config.clone(), rbac_config_path)
|
|
.await?,
|
|
);
|
|
|
|
let app_state = AppState {
|
|
auth_repository,
|
|
jwt_service,
|
|
rbac_service,
|
|
config: config.clone(),
|
|
feature_config,
|
|
};
|
|
|
|
Ok(Self { app_state, config })
|
|
}
|
|
|
|
/// Build the application router with conditional RBAC middleware
|
|
pub fn build_router(&self) -> Router {
|
|
let mut app = Router::new()
|
|
// Health check endpoints
|
|
.route("/health", get(health_check))
|
|
.route("/api/health", get(api_health_check))
|
|
.route("/api/features", get(feature_status))
|
|
// Authentication routes (always available)
|
|
.route("/api/auth/login", post(auth_login))
|
|
.route("/api/auth/register", post(auth_register))
|
|
.route("/api/auth/refresh", post(auth_refresh))
|
|
.route("/api/auth/logout", post(auth_logout))
|
|
// User profile routes
|
|
.route("/api/user/profile", get(get_user_profile))
|
|
.route("/api/user/profile", post(update_user_profile))
|
|
// Content routes (with conditional RBAC protection)
|
|
.route("/api/content/:content_id", get(get_content))
|
|
.route("/api/content/:content_id", post(update_content))
|
|
// Database access routes (with conditional RBAC protection)
|
|
.route("/api/database/:db_name", get(get_database_info))
|
|
.route("/api/database/:db_name/query", post(execute_database_query))
|
|
// File access routes (with conditional RBAC protection)
|
|
.route("/api/files/*path", get(read_file))
|
|
.route("/api/files/*path", post(write_file))
|
|
// Admin routes (always require admin role, optionally enhanced with RBAC)
|
|
.route("/api/admin/users", get(list_users))
|
|
.route("/api/admin/users/:user_id", get(get_user))
|
|
.route("/api/admin/users/:user_id", post(update_user))
|
|
.route(
|
|
"/api/admin/users/:user_id",
|
|
axum::routing::delete(delete_user),
|
|
);
|
|
|
|
// Add RBAC-specific routes if enabled
|
|
if let Some(rbac_routes) = self.app_state.rbac_service.create_rbac_routes() {
|
|
app = app.merge(rbac_routes);
|
|
}
|
|
|
|
// Apply middleware layers
|
|
app = app.layer(
|
|
ServiceBuilder::new()
|
|
.layer(TraceLayer::new_for_http())
|
|
.layer(CorsLayer::permissive())
|
|
// Always apply authentication middleware
|
|
.layer(middleware::from_fn_with_state(
|
|
(
|
|
self.app_state.jwt_service.clone(),
|
|
self.app_state.auth_repository.clone(),
|
|
),
|
|
auth_middleware,
|
|
)),
|
|
);
|
|
|
|
// Apply RBAC middleware conditionally
|
|
app = app.apply_rbac_if_enabled(&self.app_state.rbac_service);
|
|
|
|
// Apply specific RBAC middleware to protected routes
|
|
app = self.apply_route_specific_rbac_middleware(app);
|
|
|
|
app.with_state(self.app_state.clone())
|
|
}
|
|
|
|
/// Apply route-specific RBAC middleware conditionally
|
|
fn apply_route_specific_rbac_middleware(
|
|
&self,
|
|
mut router: Router<AppState>,
|
|
) -> Router<AppState> {
|
|
// Database access protection
|
|
if self
|
|
.app_state
|
|
.rbac_service
|
|
.is_feature_enabled("database_access")
|
|
{
|
|
router = router
|
|
.route(
|
|
"/api/database/:db_name",
|
|
get(get_database_info).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.database_access_middleware("*".to_string(), "read".to_string())
|
|
.unwrap(),
|
|
)),
|
|
)
|
|
.route(
|
|
"/api/database/:db_name/query",
|
|
post(execute_database_query).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.database_access_middleware("*".to_string(), "write".to_string())
|
|
.unwrap(),
|
|
)),
|
|
);
|
|
}
|
|
|
|
// File access protection
|
|
if self
|
|
.app_state
|
|
.rbac_service
|
|
.is_feature_enabled("file_access")
|
|
{
|
|
router = router
|
|
.route(
|
|
"/api/files/*path",
|
|
get(read_file).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.file_access_middleware("*".to_string(), "read".to_string())
|
|
.unwrap(),
|
|
)),
|
|
)
|
|
.route(
|
|
"/api/files/*path",
|
|
post(write_file).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.file_access_middleware("*".to_string(), "write".to_string())
|
|
.unwrap(),
|
|
)),
|
|
);
|
|
}
|
|
|
|
// Content access protection
|
|
if self
|
|
.app_state
|
|
.rbac_service
|
|
.is_feature_enabled("content_access")
|
|
{
|
|
router = router
|
|
.route(
|
|
"/api/content/:content_id",
|
|
get(get_content).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.content_access_middleware("*".to_string(), "read".to_string())
|
|
.unwrap(),
|
|
)),
|
|
)
|
|
.route(
|
|
"/api/content/:content_id",
|
|
post(update_content).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.content_access_middleware("*".to_string(), "write".to_string())
|
|
.unwrap(),
|
|
)),
|
|
);
|
|
}
|
|
|
|
// Admin routes with category protection
|
|
if self.app_state.rbac_service.is_feature_enabled("categories") {
|
|
router = router
|
|
.route(
|
|
"/api/admin/users",
|
|
get(list_users).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.category_access_middleware(vec!["admin".to_string()])
|
|
.unwrap(),
|
|
)),
|
|
)
|
|
.route(
|
|
"/api/admin/users/:user_id",
|
|
get(get_user).layer(middleware::from_fn(
|
|
self.app_state
|
|
.rbac_service
|
|
.category_access_middleware(vec!["admin".to_string()])
|
|
.unwrap(),
|
|
)),
|
|
);
|
|
}
|
|
|
|
router
|
|
}
|
|
|
|
/// Start the server
|
|
pub async fn start(&self) -> Result<()> {
|
|
let app = self.build_router();
|
|
let addr = self.config.server_address();
|
|
|
|
// Print startup information
|
|
println!("🚀 Starting Rustelo server...");
|
|
println!("📍 Address: {}", addr);
|
|
println!("🌍 Environment: {:?}", self.config.server.environment);
|
|
|
|
if self.app_state.rbac_service.is_enabled() {
|
|
println!("🔐 RBAC System: Enabled");
|
|
let status = self.app_state.rbac_service.get_feature_status();
|
|
println!(
|
|
" └─ Features: {}",
|
|
serde_json::to_string_pretty(&status["features"])?
|
|
);
|
|
} else {
|
|
println!("🔒 RBAC System: Disabled (using basic role-based auth)");
|
|
}
|
|
|
|
// Start background tasks
|
|
self.start_background_tasks().await;
|
|
|
|
// Start the server
|
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
|
println!("✅ Server running on {}", addr);
|
|
|
|
axum::serve(listener, app).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Start background tasks conditionally
|
|
async fn start_background_tasks(&self) {
|
|
// Start RBAC background tasks if enabled
|
|
self.app_state.rbac_service.start_background_tasks().await;
|
|
|
|
// Start general maintenance tasks
|
|
tokio::spawn(async {
|
|
let mut interval = interval(Duration::from_secs(3600)); // 1 hour
|
|
loop {
|
|
interval.tick().await;
|
|
println!("🧹 Running periodic maintenance tasks...");
|
|
// Add general cleanup tasks here
|
|
}
|
|
});
|
|
|
|
println!("🚀 Background tasks started");
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Route Handlers
|
|
// =============================================================================
|
|
|
|
/// Health check endpoint
|
|
async fn health_check() -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"status": "ok",
|
|
"timestamp": chrono::Utc::now(),
|
|
"service": "rustelo"
|
|
})))
|
|
}
|
|
|
|
/// API health check with database connectivity
|
|
async fn api_health_check(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
// Check database connectivity
|
|
let db_status = match sqlx::query("SELECT 1")
|
|
.fetch_one(&state.auth_repository.pool)
|
|
.await
|
|
{
|
|
Ok(_) => "connected",
|
|
Err(_) => "disconnected",
|
|
};
|
|
|
|
Ok(Json(json!({
|
|
"status": "ok",
|
|
"timestamp": chrono::Utc::now(),
|
|
"service": "rustelo",
|
|
"database": db_status,
|
|
"rbac_enabled": state.rbac_service.is_enabled()
|
|
})))
|
|
}
|
|
|
|
/// Feature status endpoint
|
|
async fn feature_status(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"auth": state.feature_config.auth.enabled,
|
|
"rbac": state.rbac_service.get_feature_status(),
|
|
"content": state.feature_config.content.enabled,
|
|
"security": {
|
|
"csrf": state.feature_config.security.csrf,
|
|
"rate_limiting": state.feature_config.security.rate_limiting
|
|
},
|
|
"performance": {
|
|
"caching": state.feature_config.performance.response_caching,
|
|
"compression": state.feature_config.performance.compression
|
|
}
|
|
})))
|
|
}
|
|
|
|
/// Authentication endpoints (simplified - integrate with full auth service)
|
|
async fn auth_login(
|
|
State(state): State<AppState>,
|
|
Json(credentials): Json<shared::auth::LoginCredentials>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Login successful",
|
|
"rbac_enabled": state.rbac_service.is_enabled()
|
|
})))
|
|
}
|
|
|
|
async fn auth_register(
|
|
State(state): State<AppState>,
|
|
Json(user_data): Json<shared::auth::RegisterUserData>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Registration successful"
|
|
})))
|
|
}
|
|
|
|
async fn auth_refresh(
|
|
State(state): State<AppState>,
|
|
Json(request): Json<shared::auth::RefreshTokenRequest>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Token refreshed"
|
|
})))
|
|
}
|
|
|
|
async fn auth_logout(State(state): State<AppState>) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Logged out successfully"
|
|
})))
|
|
}
|
|
|
|
/// User profile endpoints
|
|
async fn get_user_profile(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"profile": {
|
|
"username": "example_user",
|
|
"email": "user@example.com",
|
|
"roles": ["user"]
|
|
}
|
|
})))
|
|
}
|
|
|
|
async fn update_user_profile(
|
|
State(state): State<AppState>,
|
|
Json(profile): Json<shared::auth::UpdateUserData>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"message": "Profile updated successfully"
|
|
})))
|
|
}
|
|
|
|
/// Content endpoints (protected by conditional RBAC)
|
|
async fn get_content(
|
|
axum::extract::Path(content_id): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"content_id": content_id,
|
|
"title": "Example Content",
|
|
"body": "This is example content...",
|
|
"protection": if state.rbac_service.is_feature_enabled("content_access") {
|
|
"RBAC Protected"
|
|
} else {
|
|
"Basic Role Protected"
|
|
}
|
|
})))
|
|
}
|
|
|
|
async fn update_content(
|
|
axum::extract::Path(content_id): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
Json(content): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"content_id": content_id,
|
|
"message": "Content updated successfully"
|
|
})))
|
|
}
|
|
|
|
/// Database endpoints (protected by conditional RBAC)
|
|
async fn get_database_info(
|
|
axum::extract::Path(db_name): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"database": db_name,
|
|
"status": "accessible",
|
|
"protection": if state.rbac_service.is_feature_enabled("database_access") {
|
|
"RBAC Protected"
|
|
} else {
|
|
"Basic Role Protected"
|
|
}
|
|
})))
|
|
}
|
|
|
|
async fn execute_database_query(
|
|
axum::extract::Path(db_name): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
Json(query): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"database": db_name,
|
|
"message": "Query executed successfully"
|
|
})))
|
|
}
|
|
|
|
/// File endpoints (protected by conditional RBAC)
|
|
async fn read_file(
|
|
axum::extract::Path(file_path): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"file_path": file_path,
|
|
"content": "File content here...",
|
|
"protection": if state.rbac_service.is_feature_enabled("file_access") {
|
|
"RBAC Protected"
|
|
} else {
|
|
"Basic Role Protected"
|
|
}
|
|
})))
|
|
}
|
|
|
|
async fn write_file(
|
|
axum::extract::Path(file_path): axum::extract::Path<String>,
|
|
State(state): State<AppState>,
|
|
Json(request): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"file_path": file_path,
|
|
"message": "File written successfully"
|
|
})))
|
|
}
|
|
|
|
/// Admin endpoints (protected by conditional RBAC categories)
|
|
async fn list_users(State(state): State<AppState>) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"users": [],
|
|
"protection": if state.rbac_service.is_feature_enabled("categories") {
|
|
"RBAC Category Protected (admin)"
|
|
} else {
|
|
"Basic Admin Role Protected"
|
|
}
|
|
})))
|
|
}
|
|
|
|
async fn get_user(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"user_id": user_id,
|
|
"user": {},
|
|
"message": "User retrieved successfully"
|
|
})))
|
|
}
|
|
|
|
async fn update_user(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
Json(user_data): Json<serde_json::Value>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"user_id": user_id,
|
|
"message": "User updated successfully"
|
|
})))
|
|
}
|
|
|
|
async fn delete_user(
|
|
axum::extract::Path(user_id): axum::extract::Path<uuid::Uuid>,
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
|
Ok(Json(json!({
|
|
"success": true,
|
|
"user_id": user_id,
|
|
"message": "User deleted successfully"
|
|
})))
|
|
}
|
|
|
|
/// CLI entry point for the RBAC-compatible server
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Initialize tracing
|
|
tracing_subscriber::fmt::init();
|
|
|
|
// Load configuration
|
|
let config = Config::load().await?;
|
|
|
|
// Validate configuration
|
|
config.validate()?;
|
|
|
|
// Create and start server
|
|
let server = RBACCompatibleServer::new(config).await?;
|
|
server.start().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the default configuration can be constructed.
    ///
    /// Building a full `RBACCompatibleServer` needs a live database, so this
    /// only exercises `Config::default()`. The previous `assert!(true)` checked
    /// nothing and tripped clippy's `assertions_on_constants`.
    #[test]
    fn test_server_creation() {
        let _config = Config::default();
    }

    /// With no RBAC service or repository configured, the middleware factories
    /// must return `None`.
    #[test]
    fn test_conditional_middleware() {
        let feature_config = Arc::new(FeatureConfig::default());
        let rbac_service = ConditionalRBACService {
            rbac_service: None,
            rbac_repository: None,
            feature_config,
        };

        // Test that middleware functions return None when RBAC is disabled
        assert!(
            rbac_service
                .database_access_middleware("test".to_string(), "read".to_string())
                .is_none()
        );
        assert!(
            rbac_service
                .file_access_middleware("test".to_string(), "read".to_string())
                .is_none()
        );
        assert!(
            rbac_service
                .category_access_middleware(vec!["admin".to_string()])
                .is_none()
        );
    }
}
|