//! API middleware for authentication and authorization
use axum::{
|
|
extract::{Request, State},
|
|
http::HeaderMap,
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use std::sync::Arc;
|
|
use tracing::{error, warn};
|
|
|
|
use crate::auth::extract_bearer_token;
|
|
use crate::core::VaultCore;
|
|
|
|
/// Authentication middleware that validates Bearer tokens
|
|
pub async fn auth_middleware(
|
|
State(vault): State<Arc<VaultCore>>,
|
|
headers: HeaderMap,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
// System health endpoints don't require authentication
|
|
if request.uri().path().starts_with("/v1/sys/health")
|
|
|| request.uri().path().starts_with("/v1/sys/status")
|
|
|| request.uri().path().starts_with("/v1/sys/init")
|
|
{
|
|
return next.run(request).await;
|
|
}
|
|
|
|
// Check for bearer token
|
|
match extract_bearer_token(&headers) {
|
|
Some(token) => {
|
|
// Validate token
|
|
match vault.token_manager.validate(&token).await {
|
|
Ok(true) => {
|
|
// Token is valid, continue to next handler
|
|
next.run(request).await
|
|
}
|
|
Ok(false) => {
|
|
warn!("Invalid or expired token");
|
|
Response::builder()
|
|
.status(axum::http::StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::from("Invalid or expired token"))
|
|
.unwrap()
|
|
}
|
|
Err(e) => {
|
|
error!("Token validation error: {}", e);
|
|
Response::builder()
|
|
.status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
|
|
.body(axum::body::Body::from("Token validation failed"))
|
|
.unwrap()
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
warn!("Missing Authorization header");
|
|
Response::builder()
|
|
.status(axum::http::StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::from(
|
|
"Missing or invalid Authorization header",
|
|
))
|
|
.unwrap()
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Request logging middleware
|
|
pub async fn logging_middleware(request: Request, next: Next) -> Response {
|
|
let method = request.method().clone();
|
|
let uri = request.uri().clone();
|
|
|
|
tracing::debug!("Request: {} {}", method, uri);
|
|
|
|
let response = next.run(request).await;
|
|
|
|
tracing::debug!("Response: {}", response.status());
|
|
|
|
response
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    // NOTE(review): these tests only call `str::starts_with` on string literals.
    // They never invoke `auth_middleware`, so they cannot catch regressions in
    // its path-exemption logic. Consider testing the real check instead.

    /// Asserts that `path` begins with `prefix`.
    fn assert_has_prefix(path: &str, prefix: &str) {
        assert!(path.starts_with(prefix));
    }

    #[test]
    fn test_system_health_path() {
        assert_has_prefix("/v1/sys/health", "/v1/sys/health");
    }

    #[test]
    fn test_system_status_path() {
        assert_has_prefix("/v1/sys/status", "/v1/sys/status");
    }
}
|