/// API middleware for authentication and authorization use axum::{ extract::{Request, State}, http::HeaderMap, middleware::Next, response::Response, }; use std::sync::Arc; use tracing::{error, warn}; use crate::auth::extract_bearer_token; use crate::core::VaultCore; /// Authentication middleware that validates Bearer tokens pub async fn auth_middleware( State(vault): State>, headers: HeaderMap, request: Request, next: Next, ) -> Response { // System health endpoints don't require authentication if request.uri().path().starts_with("/v1/sys/health") || request.uri().path().starts_with("/v1/sys/status") || request.uri().path().starts_with("/v1/sys/init") { return next.run(request).await; } // Check for bearer token match extract_bearer_token(&headers) { Some(token) => { // Validate token match vault.token_manager.validate(&token).await { Ok(true) => { // Token is valid, continue to next handler next.run(request).await } Ok(false) => { warn!("Invalid or expired token"); Response::builder() .status(axum::http::StatusCode::UNAUTHORIZED) .body(axum::body::Body::from("Invalid or expired token")) .unwrap() } Err(e) => { error!("Token validation error: {}", e); Response::builder() .status(axum::http::StatusCode::INTERNAL_SERVER_ERROR) .body(axum::body::Body::from("Token validation failed")) .unwrap() } } } None => { warn!("Missing Authorization header"); Response::builder() .status(axum::http::StatusCode::UNAUTHORIZED) .body(axum::body::Body::from( "Missing or invalid Authorization header", )) .unwrap() } } } /// Request logging middleware pub async fn logging_middleware(request: Request, next: Next) -> Response { let method = request.method().clone(); let uri = request.uri().clone(); tracing::debug!("Request: {} {}", method, uri); let response = next.run(request).await; tracing::debug!("Response: {}", response.status()); response } #[cfg(test)] mod tests { #[test] fn test_system_health_path() { let path = "/v1/sys/health"; assert!(path.starts_with("/v1/sys/health")); } #[test] fn test_system_status_path() { let path = "/v1/sys/status"; assert!(path.starts_with("/v1/sys/status")); } }