use crate::error::{ControlCenterError, Result}; use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, ConnectInfo, State, }, response::Response, }; use futures::{sink::SinkExt, stream::StreamExt}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, net::SocketAddr, sync::Arc, }; use tokio::sync::{broadcast, RwLock}; use tracing::{error, info, warn}; use uuid::Uuid; /// WebSocket connection manager pub struct WebSocketManager { connections: Arc>>, event_tx: broadcast::Sender, } /// WebSocket connection info #[derive(Debug, Clone)] pub struct WebSocketConnection { pub id: Uuid, pub user_id: Option, pub connected_at: chrono::DateTime, pub client_addr: Option, } /// WebSocket event types #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WebSocketEvent { pub event_type: String, pub data: serde_json::Value, pub timestamp: chrono::DateTime, pub target_user: Option, // None for broadcast events } /// WebSocket message types #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WebSocketMessage { #[serde(rename = "auth")] Auth { token: String }, #[serde(rename = "ping")] Ping, #[serde(rename = "pong")] Pong, #[serde(rename = "event")] Event(WebSocketEvent), #[serde(rename = "error")] Error { message: String }, #[serde(rename = "connected")] Connected { connection_id: Uuid }, } impl WebSocketManager { /// Create a new WebSocket manager pub fn new() -> Self { let (event_tx, _) = broadcast::channel(1000); Self { connections: Arc::new(RwLock::new(HashMap::new())), event_tx, } } /// Add a new connection pub async fn add_connection( &self, id: Uuid, user_id: Option, client_addr: Option, ) { let connection = WebSocketConnection { id, user_id, connected_at: chrono::Utc::now(), client_addr, }; self.connections.write().await.insert(id, connection); info!("WebSocket connection added: {}", id); } /// Remove a connection pub async fn remove_connection(&self, id: Uuid) { self.connections.write().await.remove(&id); info!("WebSocket connection removed: {}", id); } /// Update connection with user authentication pub async fn authenticate_connection(&self, id: Uuid, user_id: Uuid) { if let Some(connection) = self.connections.write().await.get_mut(&id) { connection.user_id = Some(user_id); info!("WebSocket connection authenticated: {} for user: {}", id, user_id); } } /// Broadcast event to all connections pub async fn broadcast_event(&self, event: WebSocketEvent) { if let Err(e) = self.event_tx.send(event) { error!("Failed to broadcast WebSocket event: {}", e); } } /// Send event to specific user pub async fn send_to_user(&self, user_id: Uuid, event: WebSocketEvent) { let mut event = event; event.target_user = Some(user_id); if let Err(e) = self.event_tx.send(event) { error!("Failed to send WebSocket event to user {}: {}", user_id, e); } } /// Get active connections count pub async fn get_connections_count(&self) -> usize { self.connections.read().await.len() } /// Get authenticated connections count pub async fn get_authenticated_connections_count(&self) -> usize { self.connections .read() .await .values() .filter(|conn| conn.user_id.is_some()) .count() } } impl Default for WebSocketManager { fn default() -> Self { Self::new() } } /// WebSocket handler pub async fn websocket_handler( ws: WebSocketUpgrade, ConnectInfo(addr): ConnectInfo, State(app_state): State>, ) -> Response { ws.on_upgrade(move |socket| handle_websocket(socket, addr, app_state.websocket_manager.clone())) } /// Handle individual WebSocket connection async fn handle_websocket( socket: WebSocket, client_addr: SocketAddr, ws_manager: Arc, ) { let connection_id = Uuid::new_v4(); // Add connection to manager ws_manager .add_connection(connection_id, None, Some(client_addr)) .await; let (mut sender, mut receiver) = socket.split(); // Subscribe to events let mut event_rx = ws_manager.event_tx.subscribe(); // Send connection confirmation let connected_msg = WebSocketMessage::Connected { connection_id }; if let Ok(msg) = serde_json::to_string(&connected_msg) { if let Err(e) = sender.send(Message::Text(msg.into())).await { error!("Failed to send connection message: {}", e); return; } } let ws_manager_clone = ws_manager.clone(); let sender = Arc::new(tokio::sync::Mutex::new(sender)); let sender_for_events = sender.clone(); // Handle incoming messages let incoming_task = tokio::spawn(async move { while let Some(msg) = receiver.next().await { match msg { Ok(Message::Text(text)) => { if let Err(e) = handle_websocket_message( &text, connection_id, &ws_manager_clone, sender.clone(), ) .await { warn!("Error handling WebSocket message: {}", e); } } Ok(Message::Close(_)) => { info!("WebSocket connection closed: {}", connection_id); break; } Ok(Message::Ping(data)) => { if let Ok(mut sender) = sender.try_lock() { if let Err(e) = sender.send(Message::Pong(data)).await { error!("Failed to send pong: {}", e); } } } Err(e) => { error!("WebSocket error: {}", e); break; } _ => {} } } }); // Handle outgoing events let ws_manager_clone = ws_manager.clone(); let event_task = tokio::spawn(async move { while let Ok(event) = event_rx.recv().await { // Check if event is for this connection's user or is a broadcast let should_send = if let Some(target_user) = event.target_user { // Get current connection info to check user_id if let Some(connection) = ws_manager_clone.connections.read().await.get(&connection_id) { connection.user_id == Some(target_user) } else { false } } else { true // Broadcast event }; if should_send { let message = WebSocketMessage::Event(event); if let Ok(msg) = serde_json::to_string(&message) { if let Ok(mut sender) = sender_for_events.try_lock() { if let Err(e) = sender.send(Message::Text(msg.into())).await { error!("Failed to send event message: {}", e); break; } } } } } }); // Wait for either task to complete tokio::select! { _ = incoming_task => {}, _ = event_task => {}, } // Clean up ws_manager.remove_connection(connection_id).await; } /// Handle individual WebSocket message async fn handle_websocket_message( text: &str, connection_id: Uuid, _ws_manager: &WebSocketManager, sender: Arc>>, ) -> Result<()> { let message: WebSocketMessage = serde_json::from_str(text) .map_err(|e| ControlCenterError::WebSocket(format!("Invalid message format: {}", e)))?; match message { WebSocketMessage::Auth { token: _ } => { // Here you would validate the JWT token and authenticate the connection // For now, we'll just log it info!("Received auth message for connection: {}", connection_id); // In a real implementation, you would: // 1. Validate the JWT token // 2. Extract user_id from token // 3. Update connection with user_id // ws_manager.authenticate_connection(connection_id, user_id).await; } WebSocketMessage::Ping => { let pong_msg = WebSocketMessage::Pong; if let Ok(msg) = serde_json::to_string(&pong_msg) { if let Ok(mut sender) = sender.try_lock() { sender.send(Message::Text(msg.into())).await.ok(); } } } WebSocketMessage::Pong => { // Handle pong response } _ => { warn!("Received unexpected message type from connection: {}", connection_id); } } Ok(()) } /// Create system events for broadcasting pub fn create_system_event(event_type: &str, data: serde_json::Value) -> WebSocketEvent { WebSocketEvent { event_type: event_type.to_string(), data, timestamp: chrono::Utc::now(), target_user: None, } } /// Create user-specific events pub fn create_user_event( event_type: &str, data: serde_json::Value, user_id: Uuid, ) -> WebSocketEvent { WebSocketEvent { event_type: event_type.to_string(), data, timestamp: chrono::Utc::now(), target_user: Some(user_id), } }