
I would like to classify incoming TCP streams by their first n bytes and then pass each stream to a different handler according to the classification.

I do not want to consume any bytes from the stream. Otherwise I would be passing the handlers invalid streams that are missing their first n bytes.

So `poll_peek` looks almost like what I need, since it waits for data to become available before it peeks.

However, what I would ideally need is a `poll_peek_exact` that does not return until the passed buffer is full. No such method seems to exist on `TcpStream`, so I'm not sure what the correct way is to peek at the first n bytes of a `TcpStream` without consuming them.

I could do something like:

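    use std::future::poll_fn;
    use tokio::io::ReadBuf;

    // Keep peeking until we have enough bytes to decide.
    loop {
        // `TcpStream::poll_peek` fills a `ReadBuf` and returns the number of bytes peeked.
        let mut read_buf = ReadBuf::new(&mut buf);
        let num_bytes = poll_fn(|cx| tcp_stream.poll_peek(cx, &mut read_buf)).await?;
        if num_bytes >= n {
            return classify(&buf);
        }
    }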

But I think that would be busy waiting, so it seems like a bad idea, right? I could of course add a sleep to the loop, but that also does not seem like good style to me.

So what's the right way to do that?
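The snippet above calls a `classify` function that the question never defines. A minimal sketch of one, assuming the streams can be told apart by their leading bytes; the `Kind` variants and the TLS/HTTP rules are hypothetical examples, not something the question specifies. Slice patterns let it accept a header of any length, so it works on a full peek buffer as well as on an exact n-byte header.

```rust
/// Hypothetical stream classes; the question does not say which protocols are expected.
#[derive(Debug, PartialEq)]
enum Kind {
    Tls,
    Http,
    Unknown,
}

/// Hypothetical `classify`: decides based only on the first few bytes of a stream.
fn classify(header: &[u8]) -> Kind {
    match header {
        // A TLS record starts with content type 0x16 (handshake) and major version 3.
        [0x16, 0x03, ..] => Kind::Tls,
        // Prefixes of common HTTP/1.x request methods (GET, POST, PUT).
        [b'G', b'E', b'T', ..] | [b'P', b'O', b'S', ..] | [b'P', b'U', b'T', ..] => Kind::Http,
        // Anything else, including headers too short to decide.
        _ => Kind::Unknown,
    }
}
```

Each `Kind` would then map to its own handler.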

talz
  • Is it important that you pass the actual socket to the handlers? Otherwise you could wrap the entire stream in another layer that reads and forwards the header data, and pass that one. – Finomnis Dec 02 '22 at 11:15
  • It's not a busy loop because you have an `await` in there. Internally, a lot of async IO primitives work like this. – Peter Hall Dec 02 '22 at 11:19
  • @PeterHall, I thought after the first time `poll_peek` returns, it would return immediately in all the consecutive iterations, because some data is already there (even if not enough data for me). – talz Dec 02 '22 at 11:22
  • There is no such mechanism at the OS level. You need to actually consume the bytes, store them in your own buffer until you have all of them, and then do whatever you want with these bytes. – Steffen Ullrich Dec 02 '22 at 11:28
  • @Finomnis, I probably don't have to pass the actual socket, but an object of a type that implements `AsyncRead + AsyncWrite + Unpin`. – talz Dec 02 '22 at 11:28
  • @SteffenUllrich is consuming the bytes better than a non-consuming `poll_peek` loop? – talz Dec 02 '22 at 11:31
  • @talz: if you consume the bytes you don't end up with a busy loop. – Steffen Ullrich Dec 02 '22 at 12:02
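The consume-then-replay idea from these comments can be sketched synchronously with std's `Read::chain`: read the header with `read_exact` (which blocks, rather than spins, until all bytes have arrived), then put a `Cursor` over the header in front of the remaining stream. This is a std-only analogue that assumes a read-only view is enough; `read_header_and_replay` is a hypothetical name. tokio's `AsyncReadExt::chain` is the async counterpart, but like the std version it only produces a reader, so a handler that also needs to write requires a wrapper implementing both traits.

```rust
use std::io::{self, Chain, Cursor, Read};

/// Hypothetical helper (not from the thread or any library): consumes exactly
/// `N` bytes from `stream` and returns them together with a reader that yields
/// those bytes again, followed by the rest of the stream.
fn read_header_and_replay<R: Read, const N: usize>(
    mut stream: R,
) -> io::Result<([u8; N], Chain<Cursor<[u8; N]>, R>)> {
    let mut header = [0u8; N];
    // Blocks until all N bytes have arrived, or fails with `UnexpectedEof`.
    stream.read_exact(&mut header)?;
    // `[u8; N]` is `Copy`: one copy is returned for classification and
    // another is replayed in front of the remaining stream.
    Ok((header, Cursor::new(header).chain(stream)))
}
```

The handler then reads the replaying reader as if nothing had been consumed.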

1 Answer


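Here is my attempt, based on the wrapper idea from the comments. `HeaderExtractor` consumes the header with `read_exact`, which does not busy-wait because the task only wakes up again when new data arrives. It then replays the header to whoever reads from the wrapper before forwarding further reads to the socket. Writes are passed straight through, so the wrapper still implements `AsyncRead + AsyncWrite + Unpin`.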

    use pin_project::pin_project;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
    use tokio::net::{TcpListener, TcpStream};

    use std::error::Error;

    /// Wraps a `TcpStream` whose first `S` bytes have already been read and
    /// replays those bytes to the next reader before reading from the socket again.
    #[pin_project]
    struct HeaderExtractor<const S: usize> {
        #[pin]
        socket: TcpStream,
        header: [u8; S],
        num_forwarded: usize,
    }

    impl<const S: usize> HeaderExtractor<S> {
        /// Consumes exactly `S` bytes from `socket`, waiting asynchronously until they arrive.
        pub async fn read_header(socket: TcpStream) -> std::io::Result<Self> {
            let mut this = Self {
                socket,
                header: [0; S],
                num_forwarded: 0,
            };

            this.socket.read_exact(&mut this.header).await?;

            Ok(this)
        }

        pub fn get_header(&self) -> &[u8; S] {
            &self.header
        }
    }

    impl<const S: usize> AsyncRead for HeaderExtractor<S> {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.project();

            if *this.num_forwarded < this.header.len() {
                // Replay the part of the header the reader has not seen yet,
                // as much of it as fits into `buf`.
                let leftover = &this.header[*this.num_forwarded..];

                let num_forward_now = leftover.len().min(buf.remaining());
                let forward = &leftover[..num_forward_now];
                buf.put_slice(forward);

                *this.num_forwarded += num_forward_now;

                std::task::Poll::Ready(Ok(()))
            } else {
                // Header fully replayed: read from the socket directly.
                this.socket.poll_read(cx, buf)
            }
        }
    }

    // Writes bypass the header logic and go straight to the socket.
    impl<const S: usize> AsyncWrite for HeaderExtractor<S> {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<Result<usize, std::io::Error>> {
            let this = self.project();
            this.socket.poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            let this = self.project();
            this.socket.poll_flush(cx)
        }

        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), std::io::Error>> {
            let this = self.project();
            this.socket.poll_shutdown(cx)
        }
    }

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn Error>> {
        let listener = TcpListener::bind("127.0.0.1:12345").await?;

        loop {
            // Asynchronously wait for an inbound socket.
            let (socket, _) = listener.accept().await?;

            tokio::spawn(async move {
                // Read the header inside the task, so a client that is slow to send it
                // does not block `accept`, and a failed read only ends this connection.
                let mut socket = match HeaderExtractor::<3>::read_header(socket).await {
                    Ok(socket) => socket,
                    Err(e) => {
                        eprintln!("Failed to read header: {}", e);
                        return;
                    }
                };
                println!("Got header: {:?}", socket.get_header());

                let mut buf = vec![0; 1024];

                // In a loop, read data from the socket and print it.
                loop {
                    let n = socket
                        .read(&mut buf)
                        .await
                        .expect("failed to read data from socket");

                    if n == 0 {
                        println!("Connection closed.");
                        return;
                    }

                    println!("Received: {:?}", &buf[..n]);
                }
            });
        }
    }

When I run `echo "123HelloWorld!" | nc -N localhost 12345` in another console, I get:

    Got header: [49, 50, 51]
    Received: [49, 50, 51]
    Received: [72, 101, 108, 108, 111, 87, 111, 114, 108, 100, 33, 10]
    Connection closed.
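To check the `num_forwarded` bookkeeping in isolation, the same wrapper can be mirrored with std's blocking `Read`/`Write` traits instead of `AsyncRead`/`AsyncWrite`. This is a hypothetical synchronous analogue (`HeaderReplay` is not from the answer or any library), generic over any inner stream; a deliberately small read buffer shows the header being replayed across several reads.

```rust
use std::io::{self, Read, Write};

/// Hypothetical std-only analogue of `HeaderExtractor`:
/// replays an already-consumed header, then reads from `inner`.
struct HeaderReplay<T, const S: usize> {
    inner: T,
    header: [u8; S],
    num_forwarded: usize,
}

impl<T: Read, const S: usize> HeaderReplay<T, S> {
    /// Consumes exactly `S` bytes from `inner`, blocking until they arrive.
    fn read_header(mut inner: T) -> io::Result<Self> {
        let mut header = [0u8; S];
        inner.read_exact(&mut header)?;
        Ok(Self { inner, header, num_forwarded: 0 })
    }

    fn header(&self) -> &[u8; S] {
        &self.header
    }
}

impl<T: Read, const S: usize> Read for HeaderReplay<T, S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.num_forwarded < S {
            // Same bookkeeping as `poll_read`: forward as much of the
            // not-yet-replayed header as fits into `buf`.
            let leftover = &self.header[self.num_forwarded..];
            let n = leftover.len().min(buf.len());
            buf[..n].copy_from_slice(&leftover[..n]);
            self.num_forwarded += n;
            Ok(n)
        } else {
            // Header fully replayed: delegate to the underlying stream.
            self.inner.read(buf)
        }
    }
}

impl<T: Write, const S: usize> Write for HeaderReplay<T, S> {
    // Writes bypass the header logic entirely.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
```

As in the async version, a read into an empty buffer while header bytes are still pending returns `Ok(0)`, which callers normally treat as end of stream; in practice readers do not pass empty buffers.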
Finomnis