use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, State, }, response::IntoResponse, }; use futures::{sink::SinkExt, stream::StreamExt}; use leptos::LeptosOptions; use leptos_router::RouteListing; use tokio::sync::{broadcast, Mutex}; use std::{collections::HashSet, sync::Arc}; #[derive(Clone)] pub struct WebSocketState { pub client_set: Arc>>, pub tx: broadcast::Sender, } #[derive(Clone, axum::extract::FromRef)] pub struct AppState { pub leptos_options: LeptosOptions, pub websocket_state: Arc, pub routes: Vec, } pub async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, Arc::new(state))) } async fn websocket(stream: WebSocket, state: Arc) { let state = &state.websocket_state; let (mut sender, mut receiver) = stream.split(); let mut client_set = state.client_set.lock().await; let uuid = uuid::Uuid::new_v4(); client_set.insert(uuid); drop(client_set); let mut rx = state.tx.subscribe(); let msg = format!("{uuid} joined"); println!("{uuid} joined"); let _ = state.tx.send(msg); let mut send_task = tokio::spawn(async move { while let Ok(msg) = rx.recv().await { if sender.send(Message::Text(msg)).await.is_err() { break; } } }); let tx = state.tx.clone(); let uuid_clone = uuid.clone(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(Message::Text(text))) = receiver.next().await { let _ = tx.send(format!("{uuid_clone}: {text}")); } }); tokio::select! { _ = (&mut send_task) => recv_task.abort(), _ = (&mut recv_task) => send_task.abort(), }; let msg = format!("{uuid} left"); println!("{uuid} left"); let _ = state.tx.send(msg); state.client_set.lock().await.remove(&uuid); }