
I am evaluating WebSocket frameworks for Rust. I watched a tutorial video on Warp (https://www.youtube.com/watch?v=fuiFycJpCBw), recreated that project, and then compared it with Warp's own example chat server.

I picked pieces from each approach, modified Warp's example to my liking, and then deliberately introduced some errors to see what effect they have on the code.

Specifically, I was trying to understand which error-handling branch gets executed under which circumstances.

These examples keep a HashMap, created in `main`, that maps each user id to that user's transmit channel, so iterating over it allows sending a message to every connected user.

Each new connection inserts a mapping via `users.write().await.insert(my_id, tx);`, and disconnection removes it via `users.write().await.remove(&my_id);`.
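The bookkeeping described above can be sketched without Warp. This is a minimal, std-only sketch under stated assumptions: `std::sync::RwLock` and `std::sync::mpsc::Sender<String>` stand in for the question's `tokio::sync::RwLock` and `mpsc::UnboundedSender<Message>`, and `connect`, `disconnect` and `broadcast` are hypothetical helper names.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, RwLock};

// std-only stand-in for the question's `Users` type (assumption: the real
// code uses tokio::sync::RwLock and mpsc::UnboundedSender<Message>).
type Users = Arc<RwLock<HashMap<usize, Sender<String>>>>;

/// Registers a user, like `users.write().await.insert(my_id, tx)`, and
/// returns the receiving end that a per-user send task would drain.
fn connect(users: &Users, id: usize) -> Receiver<String> {
    let (tx, rx) = channel();
    users.write().unwrap().insert(id, tx);
    rx
}

/// Unregisters a user, like `users.write().await.remove(&my_id)`.
fn disconnect(users: &Users, id: usize) {
    users.write().unwrap().remove(&id);
}

/// Forwards `text` to every user except the sender and returns how many
/// channels accepted the message.
fn broadcast(users: &Users, from: usize, text: &str) -> usize {
    let mut delivered = 0;
    for (&uid, tx) in users.read().unwrap().iter() {
        if uid != from && tx.send(format!("user {}: {}", from, text)).is_ok() {
            delivered += 1;
        }
    }
    delivered
}
```

In the async version the same calls go through `.write().await` and `.read().await`; the shape of the map and the skip-the-sender check stay the same.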

To create a send error, I do not remove the user mapping when the client disconnects. When a new message then arrives and the HashMap is iterated, it still contains the obsolete entry, the server tries to send the message through it, and execution correctly enters the error branch of the `send()` call.
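The error appears because a channel send fails once its receiving end has been dropped. In the question's code that receiver is owned by the spawned forwarding task, so `mpsc_tx.send(...)` only starts failing after that task has broken out of its loop (following a failed `ws_tx.send`). The sketch below uses std channels instead of tokio ones, and `broadcast_and_prune` is a hypothetical name; it mirrors the collect-then-remove logic of the `remove` vector in the question.

```rust
use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::sync::{Arc, RwLock};

type Users = Arc<RwLock<HashMap<usize, Sender<String>>>>;

/// Sends `text` to everyone except `from` and removes every entry whose
/// channel rejected the message (its receiver was dropped).
/// Returns the ids that were pruned.
fn broadcast_and_prune(users: &Users, from: usize, text: &str) -> Vec<usize> {
    let mut dead = Vec::new();
    // The read guard lives until the end of this `for` statement.
    for (&uid, tx) in users.read().unwrap().iter() {
        if uid != from && tx.send(text.to_string()).is_err() {
            dead.push(uid);
        }
    }
    // Only now, with the read guard released, take the write lock.
    let mut map = users.write().unwrap();
    for uid in &dead {
        map.remove(uid);
    }
    dead
}
```

Collecting the ids first matters: taking the write lock while the loop still holds the read guard would not work (tokio's `RwLock` would wait forever; std's may deadlock or panic).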

The issue is that this error branch is inside a `tokio::spawn` block, and from within it I would like to make the `users.write().await.remove(&my_id);` call that I removed from the normal flow.

I might be mistaken, but I believe this is not possible, since I don't see a way for the task to access and modify the HashMap. If I understand the problem correctly, I am supposed to create an additional channel that the task can use to send a message back to the main scope, which then removes the entry from the HashMap.

For this I use an additional `mpsc::unbounded_channel()` and call its `send` method from the error-handling branch to send the removal request.
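The control-channel workaround can be sketched with std threads in place of tokio tasks. `send_to_socket` is a hypothetical stand-in for `ws_tx.send(...)` that always fails, `spawn_forwarder` and `demo` are hypothetical names, and the `("remove-user", id)` tuple mirrors the message type of `mpsc_tx_2` in the question. Only the side that owns the `HashMap` mutates it; the worker merely asks.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};
use std::thread;

/// Hypothetical stand-in for `ws_tx.send(...)`: always fails, as a send
/// to a disconnected client would.
fn send_to_socket(_msg: &str) -> Result<(), String> {
    Err("connection closed".to_string())
}

/// Spawns the per-user forwarder. On a send error it asks the map owner
/// to remove `id` via the control channel, then exits.
fn spawn_forwarder(id: usize, control_tx: Sender<(String, usize)>) -> Sender<String> {
    let (tx, rx) = channel::<String>();
    thread::spawn(move || {
        for msg in rx {
            if let Err(e) = send_to_socket(&msg) {
                eprintln!("send error (id={}): {}", id, e);
                // Ignore a failed send: the owner may already be gone.
                let _ = control_tx.send(("remove-user".to_string(), id));
                break;
            }
        }
    });
    tx
}

/// The owning side: registers user 7, triggers one failed send, and then
/// handles the removal request itself.
fn demo() -> HashMap<usize, Sender<String>> {
    let (control_tx, control_rx) = channel::<(String, usize)>();
    let mut users = HashMap::new();
    users.insert(7, spawn_forwarder(7, control_tx));
    users[&7].send("hello".to_string()).unwrap();
    let (command, uid) = control_rx.recv().unwrap();
    if command == "remove-user" {
        users.remove(&uid);
    }
    users
}
```

`demo()` returns an empty map: the forwarder fails on the first message, sends the removal request, and the owner removes user 7.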

But this means I also need to await `next()` on the receiving end of that channel, which is a problem, because that code path is already blocked in a `while let Some(result) = ws_rx.next().await` loop waiting for the next incoming WebSocket message.

So I tried adding a `tokio::select!` block that listens both for new WebSocket messages and for the removal messages the task sends when it encounters an error. This works: I receive WebSocket messages as well as messages from the new "control" channel.

But this creates a new problem: when the client disconnects, I would expect the `ws_rx.next()` branch of the `tokio::select!` block (`ws_rx` being the WebSocket's receiving half) to fire with an error or something similar. That would let me treat the connection as disconnected and remove the client from the HashMap.

Previously, without the `tokio::select!` block, the `while let Some(result) = ws_rx.next().await` loop terminated as soon as a client disconnected, without raising an error.

I also tried calling `drop(ws_tx)` instead of sending a request message back over an additional channel, which didn't work. The core of the problem is that I want to be able to modify the HashMap from within that task.

Below is the code, which can be copy-pasted into a new project. It contains two variants, one with the `tokio::select!` block and one with the `while let Some(result) = ws_rx.next().await` loop; choose between them with the boolean in `if true { /*select*/ } else { /*while let*/ }`.

Two problems to inspect:

  1. With the `while let` variant, comment out the very last line, `users.write().await.remove(&current_id);`, to trigger the send error.
  2. With the `tokio::select!` variant, observe that `select!` does not report a disconnection on the `ws_rx.next()` branch, so the final `users.write().await.remove(&current_id);` is never reached.

What I would like is to avoid `tokio::select!` with a channel, keep the simpler `while let` variant, and modify the `users` HashMap from within the `tokio::task::spawn` code.

Apparently I can use the hashmap there, but then I can't continue using it in the main scope.

This is main.rs, which contains the problems:


//###########################################################################
use std::collections::HashMap;
use std::sync::{atomic::{AtomicUsize, Ordering}, Arc};
use env_logger::Env;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use warp::ws::{Message, WebSocket};
use warp::Filter;
use colored::Colorize;
//###########################################################################


//###########################################################################
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
type Users = Arc<RwLock<HashMap<usize, mpsc::UnboundedSender<Message>>>>;
//###########################################################################


//###########################################################################
#[tokio::main]
async fn main() {    
  env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
  let users = Users::default();
  let users = warp::any().map(move || users.clone());
  let websocket = warp::path("ws")
    .and(warp::ws())
    .and(users)
    .map(|ws: warp::ws::Ws, users| {
        ws.on_upgrade(move |socket| connect(socket, users))
    });
  let files = warp::fs::dir("./static");
  let port = 8186;
  println!("running server at 0.0.0.0:{}", port.to_string().yellow());
  warp::serve(files.or(websocket)).run(([0, 0, 0, 0], port)).await;
}
//###########################################################################


//###########################################################################
async fn connect(ws: WebSocket, users: Users) {
  let current_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
  println!("user {} connected", current_id.to_string().green());
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mut ws_tx, mut ws_rx) = ws.split();
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mpsc_tx, mpsc_rx) = mpsc::unbounded_channel(); // For passing WS messages between tasks
  let mut mpsc_stream_rx = UnboundedReceiverStream::new(mpsc_rx);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mpsc_tx_2, mpsc_rx_2) = mpsc::unbounded_channel(); // For sending `remove-request` messages
  let mut mpsc_stream_rx_2: UnboundedReceiverStream<(String, usize)> = UnboundedReceiverStream::new(mpsc_rx_2);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  tokio::task::spawn(async move {
    while let Some(message) = mpsc_stream_rx.next().await {
      //----------------------------------------------------------------
      match ws_tx.send(message).await {
        Ok(_) => {
          // println!("websocket send success (current_id={})", current_id);
        },
        Err(e) => {
          eprintln!("=============================================================");
          eprintln!("websocket send error (current_id={}): {}", current_id, e);
          eprintln!("=============================================================");
          mpsc_tx_2.send(("remove-user".to_string(), current_id)).expect("unable to send remove-user message");
          break;
        }
      };
      //----------------------------------------------------------------
    };
    // NOTE: Problem here: cannot use "users"
    // users.write().await.remove(&current_id);
    // eprintln!("websocket send task ended (current_id={})", current_id);
    // eprintln!("=============================================================");
  });
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  users.write().await.insert(current_id, mpsc_tx);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if false { // <------------------ TOGGLE THIS
    loop {
      tokio::select! {
        Some(result) = ws_rx.next() => {
          //------------------------------------------------------------------
          let msg = match result {
            Ok(msg) => msg,
            Err(e) => {
              eprintln!("=============================================================");
              eprintln!("websocket receive error(current_id={}): {}", current_id, e);
              eprintln!("=============================================================");
              break;
            }
          };
          //------------------------------------------------------------------
          if let Ok(text) = msg.to_str() {
            //----------------------------------------------------------------
            println!("got message '{}' from user {}", text, current_id);
            let new_msg = Message::text(format!("user {}: {}", current_id, text));
            //----------------------------------------------------------------
            let mut remove = Vec::new();
            for (&uid, mpsc_tx) in users.read().await.iter() {
              if current_id != uid {
                println!(" -> forwarding message '{}' to channel of user {}", text, uid);
                if let Err(e) = mpsc_tx.send(new_msg.clone()) {
                  eprintln!("=============================================================");
                  eprintln!("websocket channel error (current_id={}, uid={}): {}", current_id, uid.clone(), e);
                  eprintln!("=============================================================");
                  remove.push(uid);
                }
              }
            }
            //----------------------------------------------------------------
            if remove.len() > 0 {
              for uid in remove {
                eprintln!("removing from users (uid={})", uid);
                eprintln!("=============================================================");
                users.write().await.remove(&uid);
              }
            }
            //----------------------------------------------------------------
          };
          //------------------------------------------------------------------
        }

        Some(result) = mpsc_stream_rx_2.next() => {
          let (command, uid) = result;
          if command == "remove-user" {
            eprintln!("=============================================================");
            eprintln!("removing user {}", uid);
            eprintln!("=============================================================");
            users.write().await.remove(&uid);
          }
          else {
            eprintln!("=============================================================");
            eprintln!("unknown command {}", command);
            eprintln!("=============================================================");
          }
          break;
        }
        else => break
      }
    }
  }
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  else {
    while let Some(result) = ws_rx.next().await {
      //------------------------------------------------------------------
      let msg = match result {
        Ok(msg) => msg,
        Err(e) => {
          eprintln!("=============================================================");
          eprintln!("websocket receive error(current_id={}): {}", current_id, e);
          eprintln!("=============================================================");
          break;
        }
      };
      //------------------------------------------------------------------
      if let Ok(text) = msg.to_str() {
        //----------------------------------------------------------------
        println!("got message '{}' from user {}", text, current_id);
        let new_msg = Message::text(format!("user {}: {}", current_id, text));
        //----------------------------------------------------------------
        let mut remove = Vec::new();
        for (&uid, mpsc_tx) in users.read().await.iter() {
          if current_id != uid {
            println!(" -> forwarding message '{}' to channel of user {}", text, uid);
            if let Err(e) = mpsc_tx.send(new_msg.clone()) {
              eprintln!("=============================================================");
              eprintln!("websocket channel error (current_id={}, uid={}): {}", current_id, uid.clone(), e);
              eprintln!("=============================================================");
              remove.push(uid);
            }
          }
        }
        //----------------------------------------------------------------
        if remove.len() > 0 {
          for uid in remove {
            eprintln!("removing from users (uid={})", uid);
            eprintln!("=============================================================");
            users.write().await.remove(&uid);
          }
        }
        //----------------------------------------------------------------
      };
      //------------------------------------------------------------------
    }
  }
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  println!("user {} disconnected", current_id.to_string().red());
  users.write().await.remove(&current_id);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
}
//###########################################################################


This code is based primarily on these files:

https://github.com/seanmonstar/warp/blob/master/examples/websockets_chat.rs

https://github.com/ddprrt/warp-websockets-example/blob/main/src/main.rs

This is the content of Cargo.toml:

[package]
name = "websocket-3"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["full"] }
warp = "0.3.3"
tokio-stream = "0.1.10"
futures = "0.3.24"
env_logger = "0.9.1"
colored = "2"

And this is the index.html file residing in the static/ directory:

<!DOCTYPE html>
<html lang="en">
    <head>
        <style>
          html, body {
            color: rgba(128, 128, 128);
            background-color: rgb(32, 32, 32);
          }
        </style>
        <title>Warp Websocket 3 8186 Chat</title>
    </head>
    <body>
        <h1>Warp Websocket 3 8186 Chat</h1>
        <div id="chat">
            <p><em>Connecting...</em></p>
        </div>
        <input type="text" id="text" />
        <button type="button" id="send">Send</button>
        <script type="text/javascript">
        const chat = document.getElementById('chat');
        const text = document.getElementById('text');
        const send = document.getElementById('send');
        const uri = 'ws://' + location.host + '/ws';
        const ws = new WebSocket(uri);
        function message(data) {
            const line = document.createElement('p');
            line.innerText = data;
            chat.appendChild(line);
        }
        ws.onopen = function() {
            chat.innerHTML = '<p><em>Connected!</em></p>';
        };
        ws.onmessage = function(msg) {
            message(msg.data);
        };
        ws.onclose = function() {
            chat.getElementsByTagName('em')[0].innerText = 'Disconnected!';
        };
        send.onclick = function() {
            const msg = text.value;
            ws.send(msg);
            text.value = '';
            message('you: ' + msg);
        };
        </script>
    </body>
</html>
Herohtar
Daniel F
  • This is a very long description for the actual question you are asking :D maybe reducing it down to a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) would be a good idea. – Finomnis Oct 03 '22 at 20:00
  • Would you mind elaborating what you mean with *"awaitable HashMap"*? The normal `std::collections::HashMap` is not awaitable. – Finomnis Oct 03 '22 at 20:04
  • I'm pretty new to this, but `type Users = Arc<RwLock<HashMap<usize, mpsc::UnboundedSender<Message>>>>;` is awaitable. – Daniel F Oct 03 '22 at 20:05
  • It's **asynchronously lockable**. That's a big difference. You can't await changes in it. – Finomnis Oct 03 '22 at 20:06
  • What I meant with this was that I am calling `users.write().await.remove(&current_id);` in order to be able to remove an id, so it's kind of an awaitable, or not? – Daniel F Oct 03 '22 at 20:08
  • One question: If I perform a `let users_in_spawn = users.clone()` before spawning and then use that `users_in_spawn` inside the `tokio::spawn`, will this affect the `users` hashmap? Because this apparently solves my problem. – Daniel F Oct 03 '22 at 20:10
  • Yes, that's exactly what you should do. That's actually the whole point of using an `Arc`. Use `Arc::clone(&users)` instead of `users.clone()`, though, to avoid ambiguity with `HashMap::clone()`. – Finomnis Oct 03 '22 at 20:12
  • *"so it's kind of an awaitable, or not"* - no, the locking action is awaitable (or rather: asynchronous). The HashMap itself cannot be awaited. *awaitable* usually means that you can do `users.await`. – Finomnis Oct 03 '22 at 20:13

1 Answer


To be honest, I didn't read your entire question. It's a little too long.

Either way, I skimmed it and stumbled across this:

> Apparently I can use the hashmap there, but then I can't continue using it in the main scope.

This is incorrect. It is only true if you move `users` itself (the `Arc` around the HashMap) into the closure.

`Arc`s work a little differently with `move` closures: you have to clone them and then move the clone in:

async fn connect(ws: WebSocket, users: Users) {
    // .. some code ..
    tokio::task::spawn({
        let users = Arc::clone(&users);
        async move {
            // `users` in here is the cloned one,
            // the original one still exists
        }
    });
    // `users` can still be used here
}
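The same pattern, reduced to a runnable std-only sketch: `std::thread::spawn` replaces `tokio::task::spawn`, and the map stores `String` values instead of senders (both simplifications are assumptions for illustration; `remove_in_background` is a hypothetical name). The clone is made in a block before the `move` closure, shadowing `users`, so the caller's `Arc` is untouched.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

type Users = Arc<RwLock<HashMap<usize, String>>>;

/// Removes `id` from a spawned thread while the caller keeps its `users`.
fn remove_in_background(users: &Users, id: usize) -> thread::JoinHandle<()> {
    thread::spawn({
        // Clone the Arc (not the HashMap) in a scope before the `move`
        // closure; this shadowed `users` is what gets moved in.
        let users = Arc::clone(users);
        move || {
            users.write().unwrap().remove(&id);
        }
    })
}
```

After `join()`, the entry is gone from the caller's view of the map, and `Arc::strong_count(&users)` is back to 1 because the clone was dropped when the thread finished.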
Finomnis
  • This causes a ``borrow of moved value: `users` `` error in subsequent uses of `users`. But when I move that outside of the spawn block, rename the variable, and use the renamed variable, then it works. – Daniel F Oct 03 '22 at 20:15
  • @DanielF I think you put it in the wrong spot. You have to put it **outside** of the `async move`. That should not move it. – Finomnis Oct 03 '22 at 20:15
  • Just before the `tokio::task::spawn` I call `let users_in_spawn = Arc::clone(&users);` and then use `users_in_spawn` in the spawned code. It also works without the `Arc::clone` by just calling `let users_in_spawn = users.clone()`. – Daniel F Oct 03 '22 at 20:17
  • Yes, it works, but it's bad practice. It could be mistaken with `HashMap::clone(&users)` in some circumstances. – Finomnis Oct 03 '22 at 20:19
  • *"Just before the `tokio::task::spawn`"* - Yes, but that requires a new variable name. If you add a new scope between the `spawn` and the `async move`, you can clone it there, shadow the original name, and you don't require a new name. It's cleaner, but of course it's personal preference. – Finomnis Oct 03 '22 at 20:20
  • Oh, OK, I missed that `async move {}`. Excellent! Thank you! – Daniel F Oct 03 '22 at 20:22