81 lines
2.0 KiB
Rust
81 lines
2.0 KiB
Rust
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket, WebSocketUpgrade},
|
|
State,
|
|
},
|
|
response::IntoResponse,
|
|
};
|
|
use futures::{sink::SinkExt, stream::StreamExt};
|
|
use leptos::LeptosOptions;
|
|
use leptos_router::RouteListing;
|
|
use tokio::sync::{broadcast, Mutex};
|
|
|
|
use std::{collections::HashSet, sync::Arc};
|
|
|
|
#[derive(Clone)]
|
|
pub struct WebSocketState {
|
|
pub client_set: Arc<Mutex<HashSet<uuid::Uuid>>>,
|
|
pub tx: broadcast::Sender<String>,
|
|
}
|
|
|
|
#[derive(Clone, axum::extract::FromRef)]
|
|
pub struct AppState {
|
|
pub leptos_options: LeptosOptions,
|
|
pub websocket_state: Arc<WebSocketState>,
|
|
pub routes: Vec<RouteListing>,
|
|
}
|
|
|
|
pub async fn websocket_handler(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<AppState>,
|
|
) -> impl IntoResponse {
|
|
ws.on_upgrade(|socket| websocket(socket, Arc::new(state)))
|
|
}
|
|
|
|
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
|
let state = &state.websocket_state;
|
|
|
|
let (mut sender, mut receiver) = stream.split();
|
|
|
|
let mut client_set = state.client_set.lock().await;
|
|
|
|
let uuid = uuid::Uuid::new_v4();
|
|
|
|
client_set.insert(uuid);
|
|
drop(client_set);
|
|
|
|
let mut rx = state.tx.subscribe();
|
|
|
|
let msg = format!("{uuid} joined");
|
|
println!("{uuid} joined");
|
|
let _ = state.tx.send(msg);
|
|
|
|
let mut send_task = tokio::spawn(async move {
|
|
while let Ok(msg) = rx.recv().await {
|
|
if sender.send(Message::Text(msg)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
let tx = state.tx.clone();
|
|
let uuid_clone = uuid.clone();
|
|
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(Message::Text(text))) = receiver.next().await {
|
|
let _ = tx.send(format!("{uuid_clone}: {text}"));
|
|
}
|
|
});
|
|
|
|
tokio::select! {
|
|
_ = (&mut send_task) => recv_task.abort(),
|
|
_ = (&mut recv_task) => send_task.abort(),
|
|
};
|
|
|
|
let msg = format!("{uuid} left");
|
|
println!("{uuid} left");
|
|
let _ = state.tx.send(msg);
|
|
|
|
state.client_set.lock().await.remove(&uuid);
|
|
}
|