284 lines
9.6 KiB
Rust
284 lines
9.6 KiB
Rust
|
|
use crate::error::{ControlCenterError, Result};
|
||
|
|
use crate::models::{JwtClaims, TokenResponse};
|
||
|
|
use anyhow::Context;
|
||
|
|
use jsonwebtoken::{
|
||
|
|
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
|
||
|
|
};
|
||
|
|
// RSA key generation using rsa crate
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use std::sync::Arc;
|
||
|
|
use tracing::info;
|
||
|
|
use uuid::Uuid;
|
||
|
|
|
||
|
|
// Use the configuration from simple_config
|
||
|
|
use crate::simple_config::JwtConfig;
|
||
|
|
|
||
|
|
/// A PEM-encoded RSA key pair used for RS256 JWT signing.
///
/// Produced by `generate_rsa_key_pair` or `load_rsa_keys_from_files`.
pub struct RsaKeys {
    /// Private key in PEM form (the signing half).
    pub private_key_pem: String,
    /// Public key in PEM form (the verification half).
    pub public_key_pem: String,
}
|
||
|
|
|
||
|
|
/// JWT service for handling token operations with RS256
|
||
|
|
#[derive(Clone)]
|
||
|
|
pub struct JwtService {
|
||
|
|
config: JwtConfig,
|
||
|
|
encoding_key: Arc<EncodingKey>,
|
||
|
|
decoding_key: Arc<DecodingKey>,
|
||
|
|
validation: Validation,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl JwtService {
|
||
|
|
/// Create a new JWT service with RS256 algorithm
|
||
|
|
pub fn new(config: JwtConfig) -> Result<Self> {
|
||
|
|
let encoding_key = EncodingKey::from_rsa_pem(config.private_key_pem.as_bytes())
|
||
|
|
.context("Failed to create encoding key from private key PEM")?;
|
||
|
|
|
||
|
|
let decoding_key = DecodingKey::from_rsa_pem(config.public_key_pem.as_bytes())
|
||
|
|
.context("Failed to create decoding key from public key PEM")?;
|
||
|
|
|
||
|
|
let mut validation = Validation::new(Algorithm::RS256);
|
||
|
|
validation.set_issuer(&[&config.issuer]);
|
||
|
|
validation.set_audience(&[&config.audience]);
|
||
|
|
validation.validate_exp = true;
|
||
|
|
validation.validate_nbf = false;
|
||
|
|
|
||
|
|
info!("JWT service initialized with RS256 algorithm");
|
||
|
|
|
||
|
|
Ok(Self {
|
||
|
|
config,
|
||
|
|
encoding_key: Arc::new(encoding_key),
|
||
|
|
decoding_key: Arc::new(decoding_key),
|
||
|
|
validation,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate access token
|
||
|
|
pub fn generate_access_token(
|
||
|
|
&self,
|
||
|
|
user_id: Uuid,
|
||
|
|
session_id: Uuid,
|
||
|
|
roles: Vec<String>,
|
||
|
|
) -> Result<String> {
|
||
|
|
let claims = JwtClaims::new(
|
||
|
|
user_id,
|
||
|
|
session_id,
|
||
|
|
roles,
|
||
|
|
self.config.access_token_expiration_hours,
|
||
|
|
self.config.issuer.clone(),
|
||
|
|
self.config.audience.clone(),
|
||
|
|
);
|
||
|
|
|
||
|
|
let header = Header::new(Algorithm::RS256);
|
||
|
|
|
||
|
|
encode(&header, &claims, &self.encoding_key)
|
||
|
|
.context("Failed to encode access token")
|
||
|
|
.map_err(ControlCenterError::from)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate refresh token (longer-lived, simpler claims)
|
||
|
|
pub fn generate_refresh_token(&self, user_id: Uuid, session_id: Uuid) -> Result<String> {
|
||
|
|
let claims = RefreshTokenClaims::new(
|
||
|
|
user_id,
|
||
|
|
session_id,
|
||
|
|
self.config.refresh_token_expiration_hours,
|
||
|
|
self.config.issuer.clone(),
|
||
|
|
self.config.audience.clone(),
|
||
|
|
);
|
||
|
|
|
||
|
|
let header = Header::new(Algorithm::RS256);
|
||
|
|
|
||
|
|
encode(&header, &claims, &self.encoding_key)
|
||
|
|
.context("Failed to encode refresh token")
|
||
|
|
.map_err(ControlCenterError::from)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate both access and refresh tokens
|
||
|
|
pub fn generate_token_pair(
|
||
|
|
&self,
|
||
|
|
user_id: Uuid,
|
||
|
|
session_id: Uuid,
|
||
|
|
roles: Vec<String>,
|
||
|
|
) -> Result<TokenResponse> {
|
||
|
|
let access_token = self.generate_access_token(user_id, session_id, roles)?;
|
||
|
|
let refresh_token = self.generate_refresh_token(user_id, session_id)?;
|
||
|
|
|
||
|
|
Ok(TokenResponse::new(
|
||
|
|
access_token,
|
||
|
|
refresh_token,
|
||
|
|
self.config.access_token_expiration_hours * 3600, // Convert to seconds
|
||
|
|
session_id.to_string(),
|
||
|
|
))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Verify and decode access token
|
||
|
|
pub fn verify_access_token(&self, token: &str) -> Result<TokenData<JwtClaims>> {
|
||
|
|
decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
|
||
|
|
.context("Failed to verify access token")
|
||
|
|
.map_err(ControlCenterError::from)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Verify and decode refresh token
|
||
|
|
pub fn verify_refresh_token(&self, token: &str) -> Result<TokenData<RefreshTokenClaims>> {
|
||
|
|
decode::<RefreshTokenClaims>(token, &self.decoding_key, &self.validation)
|
||
|
|
.context("Failed to verify refresh token")
|
||
|
|
.map_err(ControlCenterError::from)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Extract user ID from token without full verification (for expired tokens)
|
||
|
|
pub fn extract_user_id_unsafe(&self, token: &str) -> Result<Uuid> {
|
||
|
|
let mut validation = self.validation.clone();
|
||
|
|
validation.validate_exp = false; // Don't validate expiration
|
||
|
|
|
||
|
|
let token_data = decode::<JwtClaims>(token, &self.decoding_key, &validation)
|
||
|
|
.context("Failed to extract user ID from token")?;
|
||
|
|
|
||
|
|
Uuid::parse_str(&token_data.claims.sub)
|
||
|
|
.context("Invalid user ID in token")
|
||
|
|
.map_err(ControlCenterError::from)
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get token expiration in seconds
|
||
|
|
pub fn get_access_token_expiration(&self) -> i64 {
|
||
|
|
self.config.access_token_expiration_hours * 3600
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Get refresh token expiration in seconds
|
||
|
|
pub fn get_refresh_token_expiration(&self) -> i64 {
|
||
|
|
self.config.refresh_token_expiration_hours * 3600
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Refresh token claims (simpler than access token)
|
||
|
|
#[derive(Debug, Serialize, Deserialize)]
|
||
|
|
pub struct RefreshTokenClaims {
|
||
|
|
pub sub: String, // Subject (user_id)
|
||
|
|
pub exp: i64, // Expiration timestamp
|
||
|
|
pub iat: i64, // Issued at timestamp
|
||
|
|
pub iss: String, // Issuer
|
||
|
|
pub aud: String, // Audience
|
||
|
|
pub session_id: String, // Session ID
|
||
|
|
pub token_type: String, // "refresh"
|
||
|
|
}
|
||
|
|
|
||
|
|
impl RefreshTokenClaims {
|
||
|
|
/// Create new refresh token claims
|
||
|
|
pub fn new(
|
||
|
|
user_id: Uuid,
|
||
|
|
session_id: Uuid,
|
||
|
|
expiration_hours: i64,
|
||
|
|
issuer: String,
|
||
|
|
audience: String,
|
||
|
|
) -> Self {
|
||
|
|
let now = chrono::Utc::now();
|
||
|
|
Self {
|
||
|
|
sub: user_id.to_string(),
|
||
|
|
exp: (now + chrono::Duration::hours(expiration_hours)).timestamp(),
|
||
|
|
iat: now.timestamp(),
|
||
|
|
iss: issuer,
|
||
|
|
aud: audience,
|
||
|
|
session_id: session_id.to_string(),
|
||
|
|
token_type: "refresh".to_string(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Generate RSA key pair for JWT signing (RS256)
|
||
|
|
pub fn generate_rsa_key_pair() -> Result<RsaKeys> {
|
||
|
|
use rsa::{RsaPrivateKey, RsaPublicKey, pkcs1::EncodeRsaPrivateKey, pkcs1::EncodeRsaPublicKey};
|
||
|
|
use rand::rngs::OsRng;
|
||
|
|
|
||
|
|
// Generate 2048-bit RSA key pair
|
||
|
|
let mut rng = OsRng;
|
||
|
|
let private_key = RsaPrivateKey::new(&mut rng, 2048)
|
||
|
|
.context("Failed to generate RSA private key")?;
|
||
|
|
|
||
|
|
let public_key = RsaPublicKey::from(&private_key);
|
||
|
|
|
||
|
|
// Convert to PEM format
|
||
|
|
let private_key_pem = private_key.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
|
||
|
|
.context("Failed to encode private key as PEM")?
|
||
|
|
.to_string();
|
||
|
|
|
||
|
|
let public_key_pem = public_key.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
|
||
|
|
.context("Failed to encode public key as PEM")?;
|
||
|
|
|
||
|
|
Ok(RsaKeys {
|
||
|
|
private_key_pem,
|
||
|
|
public_key_pem,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Load RSA key pair from files
|
||
|
|
pub fn load_rsa_keys_from_files(
|
||
|
|
private_key_path: &str,
|
||
|
|
public_key_path: &str,
|
||
|
|
) -> Result<RsaKeys> {
|
||
|
|
let private_key_pem = std::fs::read_to_string(private_key_path)
|
||
|
|
.with_context(|| format!("Failed to read private key from {}", private_key_path))?;
|
||
|
|
|
||
|
|
let public_key_pem = std::fs::read_to_string(public_key_path)
|
||
|
|
.with_context(|| format!("Failed to read public key from {}", public_key_path))?;
|
||
|
|
|
||
|
|
// Validate keys by creating encoding/decoding keys
|
||
|
|
EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
|
||
|
|
.context("Invalid private key PEM format")?;
|
||
|
|
|
||
|
|
DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
|
||
|
|
.context("Invalid public key PEM format")?;
|
||
|
|
|
||
|
|
Ok(RsaKeys {
|
||
|
|
private_key_pem,
|
||
|
|
public_key_pem,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // These tests are fully synchronous, so plain `#[test]` is used instead of
    // spinning up a tokio runtime per test.

    #[test]
    fn test_jwt_token_generation_and_verification() {
        let config = JwtConfig::default();
        let jwt_service = JwtService::new(config).unwrap();

        let user_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let roles = vec!["user".to_string(), "admin".to_string()];

        // Generate tokens
        let token_response = jwt_service
            .generate_token_pair(user_id, session_id, roles.clone())
            .unwrap();

        // Verify access token
        let access_claims = jwt_service
            .verify_access_token(&token_response.access_token)
            .unwrap();
        assert_eq!(access_claims.claims.sub, user_id.to_string());
        assert_eq!(access_claims.claims.session_id, session_id.to_string());
        assert_eq!(access_claims.claims.roles, roles);

        // The user ID can be recovered without the expiration check.
        assert_eq!(
            jwt_service
                .extract_user_id_unsafe(&token_response.access_token)
                .unwrap(),
            user_id
        );

        // Verify refresh token
        let refresh_claims = jwt_service
            .verify_refresh_token(&token_response.refresh_token)
            .unwrap();
        assert_eq!(refresh_claims.claims.sub, user_id.to_string());
        assert_eq!(refresh_claims.claims.session_id, session_id.to_string());
        assert_eq!(refresh_claims.claims.token_type, "refresh");
    }

    #[test]
    fn test_key_generation() {
        let keys = generate_rsa_key_pair().unwrap();

        // Generated keys are PKCS#8 (private) and SPKI (public) PEM.
        assert!(keys.private_key_pem.contains("-----BEGIN PRIVATE KEY-----"));
        assert!(keys.private_key_pem.contains("-----END PRIVATE KEY-----"));
        assert!(keys.public_key_pem.contains("-----BEGIN PUBLIC KEY-----"));
        assert!(keys.public_key_pem.contains("-----END PUBLIC KEY-----"));

        // Should be able to create JWT service with generated keys
        let config = JwtConfig {
            private_key_pem: keys.private_key_pem,
            public_key_pem: keys.public_key_pem,
            ..JwtConfig::default()
        };
        let jwt_service = JwtService::new(config).unwrap();

        // The generated pair must actually round-trip a signed token.
        let user_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let token = jwt_service
            .generate_access_token(user_id, session_id, vec!["user".to_string()])
            .unwrap();
        let decoded = jwt_service.verify_access_token(&token).unwrap();
        assert_eq!(decoded.claims.sub, user_id.to_string());
    }
}
|