secretumvault/src/api/middleware.rs

94 lines
2.8 KiB
Rust
Raw Normal View History

2025-12-22 21:34:01 +00:00
//! API middleware for authentication and authorization.
use axum::{
extract::{Request, State},
http::HeaderMap,
middleware::Next,
response::Response,
};
use std::sync::Arc;
use tracing::{error, warn};
use crate::auth::extract_bearer_token;
use crate::core::VaultCore;
/// Path prefixes of system endpoints that are reachable without a token.
const PUBLIC_PATH_PREFIXES: [&str; 3] = ["/v1/sys/health", "/v1/sys/status", "/v1/sys/init"];

/// Returns `true` when `path` is one of the public system endpoints or a
/// sub-path of one (for example `/v1/sys/health` or `/v1/sys/health/ready`).
///
/// The match is made on a path-segment boundary: a plain `starts_with` would
/// also let look-alike paths such as `/v1/sys/initXYZ` skip authentication.
fn is_public_path(path: &str) -> bool {
    PUBLIC_PATH_PREFIXES
        .iter()
        .any(|&prefix| match path.strip_prefix(prefix) {
            Some(rest) => rest.is_empty() || rest.starts_with('/'),
            None => false,
        })
}

/// Builds a response with the given status and a static text body.
fn plain_response(status: axum::http::StatusCode, message: &'static str) -> Response {
    Response::builder()
        .status(status)
        .body(axum::body::Body::from(message))
        // No headers are set and the status is already a valid `StatusCode`,
        // so the builder has nothing that could fail here.
        .expect("building a response from a valid status and static body cannot fail")
}

/// Authentication middleware that validates Bearer tokens.
///
/// Requests to the public system endpoints (health, status, init) are passed
/// straight to the next handler. Every other request must carry a bearer
/// token that the vault's token manager accepts.
///
/// # Responses
///
/// * `401 Unauthorized` when no bearer token can be extracted from the
///   headers, or when the token manager reports the token as not valid.
/// * `500 Internal Server Error` when token validation itself returns an error.
/// * Otherwise, whatever the next handler returns.
pub async fn auth_middleware(
    State(vault): State<Arc<VaultCore>>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Response {
    // System health endpoints don't require authentication.
    if is_public_path(request.uri().path()) {
        return next.run(request).await;
    }

    let Some(token) = extract_bearer_token(&headers) else {
        warn!("Missing Authorization header");
        return plain_response(
            axum::http::StatusCode::UNAUTHORIZED,
            "Missing or invalid Authorization header",
        );
    };

    match vault.token_manager.validate(&token).await {
        // Token is valid, continue to the next handler.
        Ok(true) => next.run(request).await,
        Ok(false) => {
            warn!("Invalid or expired token");
            plain_response(
                axum::http::StatusCode::UNAUTHORIZED,
                "Invalid or expired token",
            )
        }
        Err(e) => {
            // Details go to the log only; the client gets a generic message.
            error!("Token validation error: {}", e);
            plain_response(
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                "Token validation failed",
            )
        }
    }
}
/// Request logging middleware.
///
/// Emits a debug-level event with the request method and URI, runs the rest
/// of the stack, then emits a second debug-level event with the response
/// status. The response is returned unchanged.
pub async fn logging_middleware(request: Request, next: Next) -> Response {
    // Log before handing the request over: `next.run` takes ownership of it.
    tracing::debug!("Request: {} {}", request.method(), request.uri());

    let outcome = next.run(request).await;

    tracing::debug!("Response: {}", outcome.status());
    outcome
}
#[cfg(test)]
mod tests {
    // NOTE: these tests only check string prefixes on literals; they do not
    // call the middleware functions themselves.

    /// Prefix of the health endpoint.
    const HEALTH_PREFIX: &str = "/v1/sys/health";
    /// Prefix of the status endpoint.
    const STATUS_PREFIX: &str = "/v1/sys/status";

    /// The health endpoint path begins with the health prefix.
    #[test]
    fn test_system_health_path() {
        let candidate = "/v1/sys/health";
        assert!(candidate.starts_with(HEALTH_PREFIX));
    }

    /// The status endpoint path begins with the status prefix.
    #[test]
    fn test_system_status_path() {
        let candidate = "/v1/sys/status";
        assert!(candidate.starts_with(STATUS_PREFIX));
    }
}