321 lines
9.8 KiB
Rust
321 lines
9.8 KiB
Rust
use crate::error::{ControlCenterError, Result};
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket, WebSocketUpgrade},
|
|
ConnectInfo, State,
|
|
},
|
|
response::Response,
|
|
};
|
|
use futures::{sink::SinkExt, stream::StreamExt};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{
|
|
collections::HashMap,
|
|
net::SocketAddr,
|
|
sync::Arc,
|
|
};
|
|
use tokio::sync::{broadcast, RwLock};
|
|
use tracing::{error, info, warn};
|
|
use uuid::Uuid;
|
|
|
|
/// WebSocket connection manager
|
|
pub struct WebSocketManager {
|
|
connections: Arc<RwLock<HashMap<Uuid, WebSocketConnection>>>,
|
|
event_tx: broadcast::Sender<WebSocketEvent>,
|
|
}
|
|
|
|
/// WebSocket connection info
|
|
#[derive(Debug, Clone)]
|
|
pub struct WebSocketConnection {
|
|
pub id: Uuid,
|
|
pub user_id: Option<Uuid>,
|
|
pub connected_at: chrono::DateTime<chrono::Utc>,
|
|
pub client_addr: Option<SocketAddr>,
|
|
}
|
|
|
|
/// WebSocket event types
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WebSocketEvent {
|
|
pub event_type: String,
|
|
pub data: serde_json::Value,
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
|
pub target_user: Option<Uuid>, // None for broadcast events
|
|
}
|
|
|
|
/// WebSocket message types
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum WebSocketMessage {
|
|
#[serde(rename = "auth")]
|
|
Auth { token: String },
|
|
#[serde(rename = "ping")]
|
|
Ping,
|
|
#[serde(rename = "pong")]
|
|
Pong,
|
|
#[serde(rename = "event")]
|
|
Event(WebSocketEvent),
|
|
#[serde(rename = "error")]
|
|
Error { message: String },
|
|
#[serde(rename = "connected")]
|
|
Connected { connection_id: Uuid },
|
|
}
|
|
|
|
impl WebSocketManager {
|
|
/// Create a new WebSocket manager
|
|
pub fn new() -> Self {
|
|
let (event_tx, _) = broadcast::channel(1000);
|
|
Self {
|
|
connections: Arc::new(RwLock::new(HashMap::new())),
|
|
event_tx,
|
|
}
|
|
}
|
|
|
|
/// Add a new connection
|
|
pub async fn add_connection(
|
|
&self,
|
|
id: Uuid,
|
|
user_id: Option<Uuid>,
|
|
client_addr: Option<SocketAddr>,
|
|
) {
|
|
let connection = WebSocketConnection {
|
|
id,
|
|
user_id,
|
|
connected_at: chrono::Utc::now(),
|
|
client_addr,
|
|
};
|
|
|
|
self.connections.write().await.insert(id, connection);
|
|
info!("WebSocket connection added: {}", id);
|
|
}
|
|
|
|
/// Remove a connection
|
|
pub async fn remove_connection(&self, id: Uuid) {
|
|
self.connections.write().await.remove(&id);
|
|
info!("WebSocket connection removed: {}", id);
|
|
}
|
|
|
|
/// Update connection with user authentication
|
|
pub async fn authenticate_connection(&self, id: Uuid, user_id: Uuid) {
|
|
if let Some(connection) = self.connections.write().await.get_mut(&id) {
|
|
connection.user_id = Some(user_id);
|
|
info!("WebSocket connection authenticated: {} for user: {}", id, user_id);
|
|
}
|
|
}
|
|
|
|
/// Broadcast event to all connections
|
|
pub async fn broadcast_event(&self, event: WebSocketEvent) {
|
|
if let Err(e) = self.event_tx.send(event) {
|
|
error!("Failed to broadcast WebSocket event: {}", e);
|
|
}
|
|
}
|
|
|
|
/// Send event to specific user
|
|
pub async fn send_to_user(&self, user_id: Uuid, event: WebSocketEvent) {
|
|
let mut event = event;
|
|
event.target_user = Some(user_id);
|
|
if let Err(e) = self.event_tx.send(event) {
|
|
error!("Failed to send WebSocket event to user {}: {}", user_id, e);
|
|
}
|
|
}
|
|
|
|
/// Get active connections count
|
|
pub async fn get_connections_count(&self) -> usize {
|
|
self.connections.read().await.len()
|
|
}
|
|
|
|
/// Get authenticated connections count
|
|
pub async fn get_authenticated_connections_count(&self) -> usize {
|
|
self.connections
|
|
.read()
|
|
.await
|
|
.values()
|
|
.filter(|conn| conn.user_id.is_some())
|
|
.count()
|
|
}
|
|
}
|
|
|
|
impl Default for WebSocketManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// WebSocket handler
|
|
pub async fn websocket_handler(
|
|
ws: WebSocketUpgrade,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
State(app_state): State<Arc<crate::AppState>>,
|
|
) -> Response {
|
|
ws.on_upgrade(move |socket| handle_websocket(socket, addr, app_state.websocket_manager.clone()))
|
|
}
|
|
|
|
/// Handle individual WebSocket connection
|
|
async fn handle_websocket(
|
|
socket: WebSocket,
|
|
client_addr: SocketAddr,
|
|
ws_manager: Arc<WebSocketManager>,
|
|
) {
|
|
let connection_id = Uuid::new_v4();
|
|
|
|
// Add connection to manager
|
|
ws_manager
|
|
.add_connection(connection_id, None, Some(client_addr))
|
|
.await;
|
|
|
|
let (mut sender, mut receiver) = socket.split();
|
|
|
|
// Subscribe to events
|
|
let mut event_rx = ws_manager.event_tx.subscribe();
|
|
|
|
// Send connection confirmation
|
|
let connected_msg = WebSocketMessage::Connected { connection_id };
|
|
if let Ok(msg) = serde_json::to_string(&connected_msg) {
|
|
if let Err(e) = sender.send(Message::Text(msg.into())).await {
|
|
error!("Failed to send connection message: {}", e);
|
|
return;
|
|
}
|
|
}
|
|
|
|
let ws_manager_clone = ws_manager.clone();
|
|
let sender = Arc::new(tokio::sync::Mutex::new(sender));
|
|
let sender_for_events = sender.clone();
|
|
|
|
// Handle incoming messages
|
|
let incoming_task = tokio::spawn(async move {
|
|
while let Some(msg) = receiver.next().await {
|
|
match msg {
|
|
Ok(Message::Text(text)) => {
|
|
if let Err(e) = handle_websocket_message(
|
|
&text,
|
|
connection_id,
|
|
&ws_manager_clone,
|
|
sender.clone(),
|
|
)
|
|
.await
|
|
{
|
|
warn!("Error handling WebSocket message: {}", e);
|
|
}
|
|
}
|
|
Ok(Message::Close(_)) => {
|
|
info!("WebSocket connection closed: {}", connection_id);
|
|
break;
|
|
}
|
|
Ok(Message::Ping(data)) => {
|
|
if let Ok(mut sender) = sender.try_lock() {
|
|
if let Err(e) = sender.send(Message::Pong(data)).await {
|
|
error!("Failed to send pong: {}", e);
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!("WebSocket error: {}", e);
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Handle outgoing events
|
|
let ws_manager_clone = ws_manager.clone();
|
|
let event_task = tokio::spawn(async move {
|
|
while let Ok(event) = event_rx.recv().await {
|
|
// Check if event is for this connection's user or is a broadcast
|
|
let should_send = if let Some(target_user) = event.target_user {
|
|
// Get current connection info to check user_id
|
|
if let Some(connection) = ws_manager_clone.connections.read().await.get(&connection_id) {
|
|
connection.user_id == Some(target_user)
|
|
} else {
|
|
false
|
|
}
|
|
} else {
|
|
true // Broadcast event
|
|
};
|
|
|
|
if should_send {
|
|
let message = WebSocketMessage::Event(event);
|
|
if let Ok(msg) = serde_json::to_string(&message) {
|
|
if let Ok(mut sender) = sender_for_events.try_lock() {
|
|
if let Err(e) = sender.send(Message::Text(msg.into())).await {
|
|
error!("Failed to send event message: {}", e);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Wait for either task to complete
|
|
tokio::select! {
|
|
_ = incoming_task => {},
|
|
_ = event_task => {},
|
|
}
|
|
|
|
// Clean up
|
|
ws_manager.remove_connection(connection_id).await;
|
|
}
|
|
|
|
/// Handle individual WebSocket message
|
|
async fn handle_websocket_message(
|
|
text: &str,
|
|
connection_id: Uuid,
|
|
_ws_manager: &WebSocketManager,
|
|
sender: Arc<tokio::sync::Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
|
|
) -> Result<()> {
|
|
let message: WebSocketMessage = serde_json::from_str(text)
|
|
.map_err(|e| ControlCenterError::WebSocket(format!("Invalid message format: {}", e)))?;
|
|
|
|
match message {
|
|
WebSocketMessage::Auth { token: _ } => {
|
|
// Here you would validate the JWT token and authenticate the connection
|
|
// For now, we'll just log it
|
|
info!("Received auth message for connection: {}", connection_id);
|
|
|
|
// In a real implementation, you would:
|
|
// 1. Validate the JWT token
|
|
// 2. Extract user_id from token
|
|
// 3. Update connection with user_id
|
|
// ws_manager.authenticate_connection(connection_id, user_id).await;
|
|
}
|
|
WebSocketMessage::Ping => {
|
|
let pong_msg = WebSocketMessage::Pong;
|
|
if let Ok(msg) = serde_json::to_string(&pong_msg) {
|
|
if let Ok(mut sender) = sender.try_lock() {
|
|
sender.send(Message::Text(msg.into())).await.ok();
|
|
}
|
|
}
|
|
}
|
|
WebSocketMessage::Pong => {
|
|
// Handle pong response
|
|
}
|
|
_ => {
|
|
warn!("Received unexpected message type from connection: {}", connection_id);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Create system events for broadcasting
|
|
pub fn create_system_event(event_type: &str, data: serde_json::Value) -> WebSocketEvent {
|
|
WebSocketEvent {
|
|
event_type: event_type.to_string(),
|
|
data,
|
|
timestamp: chrono::Utc::now(),
|
|
target_user: None,
|
|
}
|
|
}
|
|
|
|
/// Create user-specific events
|
|
pub fn create_user_event(
|
|
event_type: &str,
|
|
data: serde_json::Value,
|
|
user_id: Uuid,
|
|
) -> WebSocketEvent {
|
|
WebSocketEvent {
|
|
event_type: event_type.to_string(),
|
|
data,
|
|
timestamp: chrono::Utc::now(),
|
|
target_user: Some(user_id),
|
|
}
|
|
} |