From 683233a4b54bdfe057fdee48d5383ec735cc3525 Mon Sep 17 00:00:00 2001 From: xeovalyte Date: Sat, 6 Apr 2024 17:24:27 +0200 Subject: [PATCH] Added websocket server --- application/Cargo.toml | 3 +- application/src/main.rs | 25 ++++++-- application/src/pages/add_participant.rs | 1 + application/src/util.rs | 1 + application/src/util/websocket.rs | 2 + application/src/util/websocket/server.rs | 80 ++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 application/src/util/websocket.rs create mode 100644 application/src/util/websocket/server.rs diff --git a/application/Cargo.toml b/application/Cargo.toml index 27a60c9..fdc4ebf 100644 --- a/application/Cargo.toml +++ b/application/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" crate-type = ["cdylib", "rlib"] [dependencies] -axum = { version = "0.7", optional = true } +axum = { version = "0.7", optional = true, features = [ "ws", "macros" ] } console_error_panic_hook = "0.1" leptos = { version = "0.6", features = [] } leptos_axum = { version = "0.6", optional = true } @@ -26,6 +26,7 @@ serde_json = "1.0.115" cfg-if = "1.0.0" once_cell = "1.19.0" futures = "0.3.30" +uuid = "1.8.0" [features] hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"] diff --git a/application/src/main.rs b/application/src/main.rs index 2bf13b2..1c8e0cc 100644 --- a/application/src/main.rs +++ b/application/src/main.rs @@ -1,11 +1,16 @@ +use application::util::websocket::{server::AppState, server::WebSocketState}; + #[cfg(feature = "ssr")] #[tokio::main] async fn main() { - use application::app::*; use application::fileserv::file_and_error_handler; - use axum::Router; + use application::{app::*, util::websocket::server}; + use axum::{routing::get, Router}; use leptos::*; use leptos_axum::{generate_route_list, LeptosRoutes}; + use std::collections::HashSet; + use std::sync::Arc; + use tokio::sync::{broadcast, Mutex}; application::util::surrealdb::connect() .await @@ -21,11 +26,23 @@ async fn main() { let addr = leptos_options.site_addr; let routes = generate_route_list(App); + let client_set = Arc::new(Mutex::new(HashSet::::new())); + let (tx, _rx) = broadcast::channel(100); + + let websocket_state = WebSocketState { client_set, tx }; + + let app_state = AppState { + websocket_state: websocket_state.into(), + routes: routes.clone(), + leptos_options, + }; + // build our application with a route let app = Router::new() - .leptos_routes(&leptos_options, routes, App) + .route("/ws", get(server::websocket_handler)) + .leptos_routes(&app_state, routes, App) .fallback(file_and_error_handler) - .with_state(leptos_options); + .with_state(app_state); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); logging::log!("listening on http://{}", &addr); diff --git a/application/src/pages/add_participant.rs b/application/src/pages/add_participant.rs index 1b139cd..5fc28b4 100644 --- a/application/src/pages/add_participant.rs +++ b/application/src/pages/add_participant.rs @@ -4,6 +4,7 @@ use leptos_router::ActionForm; cfg_if::cfg_if! { if #[cfg(feature = "ssr")] { use crate::util::surrealdb::{DB, schemas}; + use crate::util::websocket::server; use leptos::logging; } } diff --git a/application/src/util.rs b/application/src/util.rs index 9538350..883ef16 100644 --- a/application/src/util.rs +++ b/application/src/util.rs @@ -1 +1,2 @@ pub mod surrealdb; +pub mod websocket; diff --git a/application/src/util/websocket.rs b/application/src/util/websocket.rs new file mode 100644 index 0000000..573b0e2 --- /dev/null +++ b/application/src/util/websocket.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "ssr")] +pub mod server; diff --git a/application/src/util/websocket/server.rs b/application/src/util/websocket/server.rs new file mode 100644 index 0000000..15a9a61 --- /dev/null +++ b/application/src/util/websocket/server.rs @@ -0,0 +1,80 @@ +use axum::{ + extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + State, + }, + response::IntoResponse, +}; +use futures::{sink::SinkExt, stream::StreamExt}; +use leptos::LeptosOptions; +use leptos_router::RouteListing; +use tokio::sync::{broadcast, Mutex}; + +use std::{collections::HashSet, sync::Arc}; + +#[derive(Clone)] +pub struct WebSocketState { + pub client_set: Arc>>, + pub tx: broadcast::Sender, +} + +#[derive(Clone, axum::extract::FromRef)] +pub struct AppState { + pub leptos_options: LeptosOptions, + pub websocket_state: Arc, + pub routes: Vec, +} + +pub async fn websocket_handler( + ws: WebSocketUpgrade, + State(state): State, +) -> impl IntoResponse { + ws.on_upgrade(|socket| websocket(socket, Arc::new(state))) +} + +async fn websocket(stream: WebSocket, state: Arc) { + let state = &state.websocket_state; + + let (mut sender, mut receiver) = stream.split(); + + let mut client_set = state.client_set.lock().await; + + let uuid = uuid::Uuid::new_v4(); + + client_set.insert(uuid); + drop(client_set); + + let mut rx = state.tx.subscribe(); + + let msg = format!("{uuid} joined"); + println!("{uuid} joined"); + let _ = state.tx.send(msg); + + let mut send_task = tokio::spawn(async move { + while let Ok(msg) = rx.recv().await { + if sender.send(Message::Text(msg)).await.is_err() { + break; + } + } + }); + + let tx = state.tx.clone(); + let uuid_clone = uuid.clone(); + + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(Message::Text(text))) = receiver.next().await { + let _ = tx.send(format!("{uuid_clone}: {text}")); + } + }); + + tokio::select! { + _ = (&mut send_task) => recv_task.abort(), + _ = (&mut recv_task) => send_task.abort(), + }; + + let msg = format!("{uuid} left"); + println!("{uuid} left"); + let _ = state.tx.send(msg); + + state.client_set.lock().await.remove(&uuid); +}