321 lines
9.8 KiB
Rust
Raw Normal View History

2025-10-07 10:59:52 +01:00
use crate::error::{ControlCenterError, Result};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
ConnectInfo, State,
},
response::Response,
};
use futures::{sink::SinkExt, stream::StreamExt};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
net::SocketAddr,
sync::Arc,
};
use tokio::sync::{broadcast, RwLock};
use tracing::{error, info, warn};
use uuid::Uuid;
/// WebSocket connection manager.
///
/// Keeps a registry of open connections and a broadcast channel of events.
/// Each connection subscribes to `event_tx`; events with a `target_user` are
/// filtered per connection by the subscriber (see `handle_websocket`).
pub struct WebSocketManager {
    /// Open connections, keyed by connection id.
    connections: Arc<RwLock<HashMap<Uuid, WebSocketConnection>>>,
    /// Sender side of the event broadcast channel; every connection holds its own receiver.
    event_tx: broadcast::Sender<WebSocketEvent>,
}
/// Bookkeeping for a single open WebSocket connection.
#[derive(Debug, Clone)]
pub struct WebSocketConnection {
    /// Connection id; also the key of this entry in the manager's connection map.
    pub id: Uuid,
    /// User this connection is authenticated as, or `None` until
    /// `WebSocketManager::authenticate_connection` sets it.
    pub user_id: Option<Uuid>,
    /// UTC time at which the connection was registered with the manager.
    pub connected_at: chrono::DateTime<chrono::Utc>,
    /// Peer address of the client, if known.
    pub client_addr: Option<SocketAddr>,
}
/// An application event pushed to WebSocket clients.
///
/// Sent to clients wrapped in a `WebSocketMessage::Event` frame. Serde serializes
/// the fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketEvent {
    /// Application-defined event name.
    pub event_type: String,
    /// Arbitrary JSON payload.
    pub data: serde_json::Value,
    /// Event time in UTC; the `create_*_event` helpers set it to the creation time.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Recipient user; `None` means the event is broadcast to every connection.
    pub target_user: Option<Uuid>,
}
/// Top-level JSON frame exchanged over the socket.
///
/// Internally tagged by a `"type"` field, e.g. `{"type":"ping"}` or
/// `{"type":"auth","token":"..."}`. For `Event`, the fields of the wrapped
/// `WebSocketEvent` appear next to the `"type"` tag.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WebSocketMessage {
    /// Client -> server: authentication token (not yet validated; see `handle_websocket_message`).
    #[serde(rename = "auth")]
    Auth { token: String },
    /// Application-level ping; the server answers with `Pong`.
    #[serde(rename = "ping")]
    Ping,
    /// Reply to an application-level `Ping`.
    #[serde(rename = "pong")]
    Pong,
    /// Server -> client: an application event.
    #[serde(rename = "event")]
    Event(WebSocketEvent),
    /// Error report carrying a human-readable message.
    #[serde(rename = "error")]
    Error { message: String },
    /// Server -> client: sent once after the upgrade with the assigned connection id.
    #[serde(rename = "connected")]
    Connected { connection_id: Uuid },
}
impl WebSocketManager {
/// Create a new WebSocket manager
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(1000);
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
event_tx,
}
}
/// Add a new connection
pub async fn add_connection(
&self,
id: Uuid,
user_id: Option<Uuid>,
client_addr: Option<SocketAddr>,
) {
let connection = WebSocketConnection {
id,
user_id,
connected_at: chrono::Utc::now(),
client_addr,
};
self.connections.write().await.insert(id, connection);
info!("WebSocket connection added: {}", id);
}
/// Remove a connection
pub async fn remove_connection(&self, id: Uuid) {
self.connections.write().await.remove(&id);
info!("WebSocket connection removed: {}", id);
}
/// Update connection with user authentication
pub async fn authenticate_connection(&self, id: Uuid, user_id: Uuid) {
if let Some(connection) = self.connections.write().await.get_mut(&id) {
connection.user_id = Some(user_id);
info!("WebSocket connection authenticated: {} for user: {}", id, user_id);
}
}
/// Broadcast event to all connections
pub async fn broadcast_event(&self, event: WebSocketEvent) {
if let Err(e) = self.event_tx.send(event) {
error!("Failed to broadcast WebSocket event: {}", e);
}
}
/// Send event to specific user
pub async fn send_to_user(&self, user_id: Uuid, event: WebSocketEvent) {
let mut event = event;
event.target_user = Some(user_id);
if let Err(e) = self.event_tx.send(event) {
error!("Failed to send WebSocket event to user {}: {}", user_id, e);
}
}
/// Get active connections count
pub async fn get_connections_count(&self) -> usize {
self.connections.read().await.len()
}
/// Get authenticated connections count
pub async fn get_authenticated_connections_count(&self) -> usize {
self.connections
.read()
.await
.values()
.filter(|conn| conn.user_id.is_some())
.count()
}
}
impl Default for WebSocketManager {
fn default() -> Self {
Self::new()
}
}
/// Axum handler that upgrades the HTTP request to a WebSocket.
///
/// Once the upgrade completes, the socket is handed to `handle_websocket` together
/// with the peer address and the application's shared `WebSocketManager`.
pub async fn websocket_handler(
    ws: WebSocketUpgrade,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    State(app_state): State<Arc<crate::AppState>>,
) -> Response {
    let manager = app_state.websocket_manager.clone();
    ws.on_upgrade(move |socket| handle_websocket(socket, addr, manager))
}
/// Handle individual WebSocket connection
async fn handle_websocket(
socket: WebSocket,
client_addr: SocketAddr,
ws_manager: Arc<WebSocketManager>,
) {
let connection_id = Uuid::new_v4();
// Add connection to manager
ws_manager
.add_connection(connection_id, None, Some(client_addr))
.await;
let (mut sender, mut receiver) = socket.split();
// Subscribe to events
let mut event_rx = ws_manager.event_tx.subscribe();
// Send connection confirmation
let connected_msg = WebSocketMessage::Connected { connection_id };
if let Ok(msg) = serde_json::to_string(&connected_msg) {
if let Err(e) = sender.send(Message::Text(msg.into())).await {
error!("Failed to send connection message: {}", e);
return;
}
}
let ws_manager_clone = ws_manager.clone();
let sender = Arc::new(tokio::sync::Mutex::new(sender));
let sender_for_events = sender.clone();
// Handle incoming messages
let incoming_task = tokio::spawn(async move {
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) = handle_websocket_message(
&text,
connection_id,
&ws_manager_clone,
sender.clone(),
)
.await
{
warn!("Error handling WebSocket message: {}", e);
}
}
Ok(Message::Close(_)) => {
info!("WebSocket connection closed: {}", connection_id);
break;
}
Ok(Message::Ping(data)) => {
if let Ok(mut sender) = sender.try_lock() {
if let Err(e) = sender.send(Message::Pong(data)).await {
error!("Failed to send pong: {}", e);
}
}
}
Err(e) => {
error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
});
// Handle outgoing events
let ws_manager_clone = ws_manager.clone();
let event_task = tokio::spawn(async move {
while let Ok(event) = event_rx.recv().await {
// Check if event is for this connection's user or is a broadcast
let should_send = if let Some(target_user) = event.target_user {
// Get current connection info to check user_id
if let Some(connection) = ws_manager_clone.connections.read().await.get(&connection_id) {
connection.user_id == Some(target_user)
} else {
false
}
} else {
true // Broadcast event
};
if should_send {
let message = WebSocketMessage::Event(event);
if let Ok(msg) = serde_json::to_string(&message) {
if let Ok(mut sender) = sender_for_events.try_lock() {
if let Err(e) = sender.send(Message::Text(msg.into())).await {
error!("Failed to send event message: {}", e);
break;
}
}
}
}
}
});
// Wait for either task to complete
tokio::select! {
_ = incoming_task => {},
_ = event_task => {},
}
// Clean up
ws_manager.remove_connection(connection_id).await;
}
/// Handle individual WebSocket message
async fn handle_websocket_message(
text: &str,
connection_id: Uuid,
_ws_manager: &WebSocketManager,
sender: Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
) -> Result<()> {
let message: WebSocketMessage = serde_json::from_str(text)
.map_err(|e| ControlCenterError::WebSocket(format!("Invalid message format: {}", e)))?;
match message {
WebSocketMessage::Auth { token: _ } => {
// Here you would validate the JWT token and authenticate the connection
// For now, we'll just log it
info!("Received auth message for connection: {}", connection_id);
// In a real implementation, you would:
// 1. Validate the JWT token
// 2. Extract user_id from token
// 3. Update connection with user_id
// ws_manager.authenticate_connection(connection_id, user_id).await;
}
WebSocketMessage::Ping => {
let pong_msg = WebSocketMessage::Pong;
if let Ok(msg) = serde_json::to_string(&pong_msg) {
if let Ok(mut sender) = sender.try_lock() {
sender.send(Message::Text(msg.into())).await.ok();
}
}
}
WebSocketMessage::Pong => {
// Handle pong response
}
_ => {
warn!("Received unexpected message type from connection: {}", connection_id);
}
}
Ok(())
}
/// Build an event addressed to every connected client.
///
/// `timestamp` is set to the current UTC time and `target_user` is left unset,
/// which marks the event as a broadcast.
pub fn create_system_event(event_type: &str, data: serde_json::Value) -> WebSocketEvent {
    let timestamp = chrono::Utc::now();
    WebSocketEvent {
        target_user: None,
        timestamp,
        data,
        event_type: String::from(event_type),
    }
}
/// Create user-specific events
pub fn create_user_event(
event_type: &str,
data: serde_json::Value,
user_id: Uuid,
) -> WebSocketEvent {
WebSocketEvent {
event_type: event_type.to_string(),
data,
timestamp: chrono::Utc::now(),
target_user: Some(user_id),
}
}