
I am trying to create a server that can accept both secure SSL connections and insecure plain-text connections (for backwards compatibility). My code is almost working, except that the first data received from an insecure client is missing its first 5 bytes (chars) at the server. More specifically, if I transmit 30 bytes on an insecure connection, then when the server reaches the line "int iRx = nwStream.EndRead(asyn);" in the OnClientDataReceived() function, iRx = 25. Any subsequent messages transmitted from the client contain all the sent bytes/chars as expected. I suspect that the initial assumption that the connection is an SSLStream may be stripping the first 5 bytes, and when that attempt fails, those 5 bytes have already been extracted from the buffer and are no longer available. Does anybody know of another approach I could take so that the server can automatically switch on the fly?

I am trying to avoid doing the following:

  • Require that a client connect using a plain text NetworkStream and then request to upgrade to an SSL stream
  • Set up two TcpListeners on two different ports (one for secure, one for insecure)

Here is my code:

/// Each client that connects gets an instance of the ConnectedClient class.
class Pseudo_ConnectedClient
{
    //Properties
    byte[] Buffer; //Holds temporary buffer of read bytes from BeginRead()
    TcpClient TCPClient; //Reference to the connected client
    Socket ClientSocket; //The outer Socket Reference of the connected client
    StringBuilder CurrentMessage; //concatenated chunks of data in buffer until we have a complete message (ends with <ETX>)
    Stream Stream; //SSLStream or NetworkStream depending on client
    ArrayList MessageQueue; //Array of complete messages received from client that need to be processed
}

/// When a new client connects (OnClientConnection callback is executed), the server creates the ConnectedClient object and stores its 
/// reference in a local dictionary, then configures the callbacks for incoming data (WaitForClientData)
void OnClientConnection(IAsyncResult result)
{
    TcpListener listener = result.AsyncState as TcpListener;
    TcpClient clnt = null;

    try
    {
        if (!IsRunning) //then stop was called, so don't call EndAcceptTcpClient because it will throw an ObjectDisposedException
            return;

        //Start accepting the next connection...
        listener.BeginAcceptTcpClient(this.onClientConnection, listener);

        //Get reference to client and set flag to indicate connection accepted.
        clnt = listener.EndAcceptTcpClient(result);

        //Add the reference to our ArrayList of Connected Clients
        ConnectedClient conClnt = new ConnectedClient(clnt);
        _clientList.Add(conClnt);

        //Configure client to listen for incoming data
        WaitForClientData(conClnt);
    }
    catch (Exception ex)
    {
        Trace.WriteLine("Server:OnClientConnection: Exception - " + ex.ToString());
    }
}

/// WaitForClientData registers the AsyncCallback to handle incoming data from a client (OnClientDataReceived).
/// If a certificate has been provided, then it listens for clients to connect on an SSLStream and configures the 
/// BeginAuthenticateAsServer callback.  If no certificate is provided, then it only sets up a NetworkStream 
/// and prepares for the BeginRead callback.
private void WaitForClientData(ConnectedClient clnt)
{
    if (!IsRunning) return; //Then stop was called, so don't do anything

    SslStream sslStream = null;

    try
    {
        if (_pfnClientDataCallBack == null) //then define the call back function to invoke when data is received from a connected client
            _pfnClientDataCallBack = new AsyncCallback(OnClientDataReceived);

        NetworkStream nwStream = clnt.TCPClient.GetStream();

        //Check if we can establish a secure connection
        if (this.SSLCertificate != null) //Then we have the ability to make an SSL connection (SSLCertificate is a X509Certificate2 object)
        {
            if (this.certValidationCallback != null)
                sslStream = new SslStream(nwStream, true, this.certValidationCallback);
            else
                sslStream = new SslStream(nwStream, true);

            clnt.Stream = sslStream;

            //Start Listening for incoming (secure) data
            sslStream.BeginAuthenticateAsServer(this.SSLCertificate, false, SslProtocols.Default, false, onAuthenticateAsServer, clnt);
        }
        else //No certificate available to make a secure connection, so use insecure (unless not allowed)
        {
            if (this.RequireSecureConnection == false) //Then we can try to read from the insecure stream
            {
                clnt.Stream = nwStream;

                //Start Listening for incoming (unsecure) data
                nwStream.BeginRead(clnt.Buffer, 0, clnt.Buffer.Length, _pfnClientDataCallBack, clnt);
            }
            else //we can't do anything - report config problem
            {
                throw new InvalidOperationException("A PFX certificate is not loaded and the server is configured to require a secure connection");
            }
        }
    }
    catch (Exception ex)
    {
        DisconnectClient(clnt);
    }
}

/// OnAuthenticateAsServer first checks if the stream is authenticated, if it isn't it gets the TCPClient's reference 
/// to the outer NetworkStream (client.TCPClient.GetStream()) - the insecure stream and calls the BeginRead on that.  
/// If the stream is authenticated, then it keeps the reference to the SSLStream and calls BeginRead on it.
private void OnAuthenticateAsServer(IAsyncResult result)
{
    ConnectedClient clnt = null;
    SslStream sslStream = null;

    if (this.IsRunning == false) return;

    try
    {
        clnt = result.AsyncState as ConnectedClient;
        sslStream = clnt.Stream as SslStream;

        if (sslStream.IsAuthenticated)
            sslStream.EndAuthenticateAsServer(result);
        else //Try to switch to an insecure connection
        {
            if (this.RequireSecureConnection == false) //Then we are allowed to accept insecure connections
            {
                if (clnt.TCPClient.Connected)
                    clnt.Stream = clnt.TCPClient.GetStream();
            }
            else //Insecure connections are not allowed, close the connection
            {
                DisconnectClient(clnt);
            }
        }
    }
    catch (Exception ex)
    {
        DisconnectClient(clnt);
    }

    if( clnt.Stream != null) //Then we have a stream to read, start Async read
        clnt.Stream.BeginRead(clnt.Buffer, 0, clnt.Buffer.Length, _pfnClientDataCallBack, clnt);
}

/// OnClientDataReceived callback is triggered by the BeginRead async when data is available from a client.  
/// It determines if the stream (as assigned by OnAuthenticateAsServer) is an SSLStream or a NetworkStream 
/// and then reads the data out of the stream accordingly.  The logic to parse and process the message has 
/// been removed because it isn't relevant to the question.
private void OnClientDataReceived(IAsyncResult asyn)
{
    try
    {
        ConnectedClient connectClnt = asyn.AsyncState as ConnectedClient;

        if (!connectClnt.TCPClient.Connected) //Then the client is no longer connected >> clean up
        {
            DisconnectClient(connectClnt);
            return;
        }

        Stream nwStream = null;
        if( connectClnt.Stream is SslStream) //Then this client is connected via a secure stream
            nwStream = connectClnt.Stream as SslStream;
        else //this is a plain text stream
            nwStream = connectClnt.Stream as NetworkStream;

        // Complete the BeginRead() asynchronous call with EndRead(), which
        // returns the number of bytes read from the stream
        int iRx = nwStream.EndRead(asyn); //Returns the numbers of bytes in the read buffer
        char[] chars = new char[iRx];   

        // Extract the characters as a buffer and create a String
        Decoder d = ASCIIEncoding.UTF8.GetDecoder();
        d.GetChars(connectClnt.Buffer, 0, iRx, chars, 0);

        //string data = ASCIIEncoding.ASCII.GetString(buff, 0, buff.Length);
        string data = new string(chars);

        if (iRx > 0) //Then there was data in the buffer
        {
            //Append the current packet with any additional data that was already received
            connectClnt.CurrentMessage.Append(data);

            //Do work here to check for a complete message
            //Make sure two complete messages didn't get concatenated in one transmission (mobile devices)
            //Add each message to the client's messageQueue
            //Clear the currentMessage
            //Any partial message at the end of the buffer needs to be added to the currentMessage

            //Start reading again
            nwStream.BeginRead(connectClnt.Buffer, 0, connectClnt.Buffer.Length, OnClientDataReceived, connectClnt);
        }
        else //zero-length packet received - Disconnecting socket
        {
            DisconnectClient(connectClnt);
        }                
    }
    catch (Exception ex)
    {
        return;
    }
}
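The "Do work here" comments in OnClientDataReceived describe a framing step that was left out of the post. A minimal sketch of it, assuming (as the CurrentMessage comment says) that every message ends with <ETX> (0x03) and that the text is UTF-8; EtxFramer is a hypothetical name for a helper that would live in each ConnectedClient in place of CurrentMessage. It also keeps one Decoder per client: the code above creates a new Decoder on every callback, so a multi-byte UTF-8 character split across two reads would be decoded incorrectly.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// Hypothetical per-client helper (not from the original post): turns the bytes
// returned by each EndRead into complete messages terminated by <ETX> (0x03).
public sealed class EtxFramer
{
    private const char Etx = '\u0003';

    // One Decoder per client keeps a partial UTF-8 sequence between reads.
    private readonly Decoder _decoder = Encoding.UTF8.GetDecoder();

    // Text received so far that has not yet been terminated by <ETX>.
    private readonly StringBuilder _current = new StringBuilder();

    // Feed the first 'count' bytes of one read; returns every message they
    // complete (without the <ETX>). A trailing partial message stays buffered.
    public List<string> Append(byte[] buffer, int count)
    {
        // Exact size for this read, including any bytes left over in the decoder
        // (at least 1 so the array is never empty).
        char[] chars = new char[Math.Max(1, _decoder.GetCharCount(buffer, 0, count))];
        int charCount = _decoder.GetChars(buffer, 0, count, chars, 0);

        List<string> complete = new List<string>();
        for (int i = 0; i < charCount; i++)
        {
            if (chars[i] == Etx)
            {
                // Several messages can arrive in one read; each <ETX> closes one.
                complete.Add(_current.ToString());
                _current.Length = 0;
            }
            else
            {
                _current.Append(chars[i]);
            }
        }
        return complete;
    }
}
```

In OnClientDataReceived, a call such as `connectClnt.Framer.Append(connectClnt.Buffer, iRx)` (Framer being a hypothetical new ConnectedClient field) would replace the decode-and-append lines, and each returned string could be added to MessageQueue.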

What works:

  • If the server doesn't have a certificate, only a NetworkStream is used, and all bytes are received from the client for all messages.
  • If the server does have a certificate (an SSLStream is set up) and a secure connection can be established (e.g. a web browser using https://), the full message is received for all messages.

What doesn't work:

  • If the server does have a certificate (an SSLStream is set up) and an insecure connection is made from a client, then when the first message is received from that client, the code correctly detects that the SSLStream is not authenticated and switches to the NetworkStream of the TCPClient. However, when EndRead is called on that NetworkStream for the first message, the first 5 chars (bytes) of the message that was sent are missing, but only for the first message. All remaining messages are complete as long as the TCPClient is connected. If the client disconnects and then reconnects, the first message is clipped again, and then all subsequent messages are good again.

What is causing those first 5 bytes to be clipped, and how can I avoid it?

My project is currently using .NET v3.5... I would like to remain at this version and not step up to 4.0 if I can avoid it.


Follow-up Question

Damien's answer below does allow me to retain those missing 5 bytes; however, I would prefer to stick with the BeginRead and EndRead methods in my code to avoid blocking. Are there any good tutorials showing 'best practices' for overriding these? More specifically, how do I work with the IAsyncResult object? I get that I would need to add any content that is stored in the RestartableStream buffers, then fall through to the inner stream (base) to get the rest and return the total. But since the IAsyncResult object is a custom class, I can't figure out a generic way to combine the buffers of RestartableStream with those of the inner stream before returning. Do I also need to implement BeginRead() so that I know which buffers the caller wants the content stored into? I guess the other solution, since the dropped-bytes problem only affects the first message from the client (after that I know whether to use it as an SSLStream or a NetworkStream), would be to handle that first message by calling the Read() method of RestartableStream directly (temporarily blocking the code), and then use the async callbacks to read the contents of all future messages as I do now.

Jerren Saunders
  • Why don't you want to open another port? That's what HTTP(S) does. – Daniel A. White Mar 05 '13 at 15:28
  • Getting a port opened on our company's firewall is a long process. We have one port open now and would like to use the same port for both secure and insecure connections if possible. – Jerren Saunders Mar 05 '13 at 15:42
  • What you're trying to do is sometimes referred to as "port unification". It's quite uncommon, and often confusing. Another way would be to change your protocol to have a "STARTTLS" command; a number of protocols do that. – Bruno Mar 06 '13 at 20:52

2 Answers


Okay, I think the best you can do is place your own class in between SslStream and NetworkStream where you implement some customized buffering. I've done a few tests on the below, but I'd recommend a few more before you put it in production (and probably some more robust error handling). I think I've avoided any 4.0 or 4.5isms:

  public sealed class RestartableReadStream : Stream
  {
    private Stream _inner;
    private List<byte[]> _buffers;
    private bool _buffering;
    private int? _currentBuffer = null;
    private int? _currentBufferPosition = null;
    public RestartableReadStream(Stream inner)
    {
      if (!inner.CanRead) throw new NotSupportedException(); //Don't know what else is being expected of us
      if (inner.CanSeek) throw new NotSupportedException(); //Just use the underlying streams ability to seek, no need for this class
      _inner = inner;
      _buffering = true;
      _buffers = new List<byte[]>();
    }

    public void StopBuffering()
    {
      _buffering = false;
      if (!_currentBuffer.HasValue)
      {
        //We aren't currently using the buffers
        _buffers = null;
        _currentBufferPosition = null;
      }
    }

    public void Restart()
    {
      if (!_buffering) throw new NotSupportedException();  //Buffering got turned off already
      if (_buffers.Count == 0) return;
      _currentBuffer = 0;
      _currentBufferPosition = 0;
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
      if (_currentBuffer.HasValue)
      {
        //Try to satisfy the read request from the current buffer
        byte[] rbuffer = _buffers[_currentBuffer.Value];
        int roffset = _currentBufferPosition.Value;
        if ((rbuffer.Length - roffset) <= count)
        {
          //Just give them what we have in the current buffer (exhausting it)
          count = (rbuffer.Length - roffset);
          for (int i = 0; i < count; i++)
          {
            buffer[offset + i] = rbuffer[roffset + i];
          }

          _currentBuffer++;
          if (_currentBuffer.Value == _buffers.Count)
          {
            //We've stopped reading from the buffers
            if (!_buffering)
              _buffers = null;
            _currentBuffer = null;
            _currentBufferPosition = null;
          }
          return count;
        }
        else
        {
          for (int i = 0; i < count; i++)
          {
            buffer[offset + i] = rbuffer[roffset + i];
          }
          _currentBufferPosition += count;
          return count;
        }
      }
      //If we reach here, we're currently using the inner stream. But may be buffering the results
      int ncount = _inner.Read(buffer, offset, count);
      if (_buffering)
      {
        byte[] rbuffer = new byte[ncount];
        for (int i = 0; i < ncount; i++)
        {
          rbuffer[i] = buffer[offset + i];
        }
        _buffers.Add(rbuffer);
      }
      return ncount;
    }

    public override bool CanRead
    {
      get { return true; }
    }

    public override bool CanSeek
    {
      get { return false; }
    }

    public override bool CanWrite
    {
      get { return false; }
    }

    //No more interesting code below here

    public override void Flush()
    {
      throw new NotSupportedException();
    }

    public override long Length
    {
      get { throw new NotSupportedException(); }
    }

    public override long Position
    {
      get
      {
        throw new NotSupportedException();
      }
      set
      {
        throw new NotSupportedException();
      }
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
      throw new NotSupportedException();
    }

    public override void SetLength(long value)
    {
      throw new NotSupportedException();
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
      throw new NotSupportedException();
    }
  }

Usage:

Construct a RestartableReadStream around your NetworkStream. Pass that instance to SslStream. If you decide that SSL was the wrong way to do things, call Restart() and then use it again however you want to. You can even try more than two strategies (calling Restart() between each one).

Once you've settled on which strategy (e.g. SSL or non-SSL) is correct, call StopBuffering(). Once it's finished replaying any buffers it had available, it will revert to just calling Read on its inner stream. If you don't call StopBuffering, then the entire history of reads from the stream will be kept in the _buffers list, which could add quite a bit of memory pressure.

Note that none of the above particularly accounts for multi-threaded access. But if you've got multiple threads calling Read() on a single stream (especially one that's network based), I wouldn't expect any sanity anyway.

Damien_The_Unbeliever
  • I implemented your RestartableStream class and it is definitely retaining the lost 5 bytes! Thank you!!! I have two follow up questions though... since I'm new to overflow, do I ask them in a comment to your question, or start a new question and somehow link them to this one? – Jerren Saunders Mar 06 '13 at 17:26
  • @JerrenSaunders - if it's just a minor addition to what's already been asked, and still directly related to the original question, you should ideally edit your question (there's an edit link at the bottom) and add your follow-up questions there. I'd usually say if it means more than adding a paragraph or two though, or is less related to the original question, it may be better to post as a new question. Unfortunately, there's no "one size fits all" advice to give on this, and it comes down to how related and how close to the original, so far as I'm concerned. – Damien_The_Unbeliever Mar 06 '13 at 17:44
  • @JerrenSaunders - also, if you feel the answer has helped, please upvote it. If it answers your question entirely (it doesn't sound like it does yet), please hit the check mark next to it. – Damien_The_Unbeliever Mar 06 '13 at 17:47
  • It almost does. One question is more appropriately asked as a comment (the other one I'll put in as a followup): An SslStream() requires that the given stream be read/write. I notice that you forced the CanWrite property to false for your class. If I verify that the connection is an SSLStream, and use the SSLStream.Write() for those clients, and use the RestartableStream.Write()=>base.Write() for the clients that are just a regular NetworkStream (plain-text), are there any problems that I may be overlooking that you were trying to avoid? (I have a bad cold today, so I'm a little foggy). – Jerren Saunders Mar 06 '13 at 17:49
  • @JerrenSaunders - you can modify the above class to allow `Write()` calls to occur - but I'd recommend that a) you raise an exception if you're currently replaying (e.g. `_currentBuffer.HasValue` is true), and b) That you set `_buffering` false. (And then modify my code to call `_inner.CanWrite` inside my `CanWrite`). I can't point to any specific problems that will occur without doing these, it just feels slightly spooky if something is `Write`ing during a replay of parts of the stream that have already been read at least once before. – Damien_The_Unbeliever Mar 06 '13 at 17:58
  • OK. That makes sense. In my case, I only intend to write (respond) to the client (via the stream) once the server has determined the correct stream type. So by that point the StopBuffering method has been called and the buffers cleared. But good point to safe-guard against those situations in case this class gets re-used down the road. – Jerren Saunders Mar 06 '13 at 18:04
  • Bad news. I was so focused on recovering those missing 5 bytes for the insecure connections, that I didn't retest the secure connection. Just before the SSLStream OnAuthenticateAsServer callback is executed, the RestartableStream Write method is called... I assume to do the SSL handshaking. The base stream doesn't have a Write method I can simply pass to, only a WriteByte method... however if I try to iterate through the given buffer and call: `foreach(...){ base.WriteByte(buffer[i]);}` it actually does a recursive call on RestartableStream.Write (the overridden method)... :-( – Jerren Saunders Mar 13 '13 at 15:07
  • Nevermind... instead of calling `base.WriteByte()`, I realized I need to call the inner stream's Write() method. It's kinda bad to be this slow on a Wednesday... – Jerren Saunders Mar 13 '13 at 15:11
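On the follow-up about BeginRead/EndRead: the base Stream class already implements BeginRead and EndRead by running the overridden Read asynchronously, so a wrapper that only overrides Read should still work with the question's callbacks, at the cost of a thread-pool thread blocked inside each read. If you want pass-through async reads instead, one pattern is to serve already-buffered bytes through a small custom IAsyncResult that completes synchronously, and to hand every other call straight to the inner stream's BeginRead/EndRead. The sketch below assumes the bytes consumed during detection are available as one byte array (for example from the blocking first Read() the follow-up suggests); PrefixReplayStream and CompletedReadResult are hypothetical names, and this is not the RestartableReadStream from the answer.

```csharp
using System;
using System.IO;
using System.Threading;

// Hypothetical wrapper: returns a fixed prefix of already-consumed bytes first,
// then reads from the inner stream (e.g. the client's NetworkStream).
public sealed class PrefixReplayStream : Stream
{
    private readonly Stream _inner;
    private readonly byte[] _prefix;
    private int _prefixPos;

    public PrefixReplayStream(Stream inner, byte[] prefix)
    {
        if (inner == null) throw new ArgumentNullException("inner");
        _inner = inner;
        _prefix = prefix ?? new byte[0];
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        int n = CopyFromPrefix(buffer, offset, count);
        return n > 0 ? n : _inner.Read(buffer, offset, count);
    }

    public override IAsyncResult BeginRead(byte[] buffer, int offset, int count,
                                           AsyncCallback callback, object state)
    {
        int n = CopyFromPrefix(buffer, offset, count);
        if (n == 0)
        {
            // Prefix exhausted: the inner stream's own IAsyncResult goes to the caller.
            return _inner.BeginRead(buffer, offset, count, callback, state);
        }
        // Served from memory: complete synchronously and still invoke the callback.
        CompletedReadResult done = new CompletedReadResult(n, state);
        if (callback != null) callback(done);
        return done;
    }

    public override int EndRead(IAsyncResult asyncResult)
    {
        // Our own result type means prefix bytes; anything else belongs to the inner stream.
        CompletedReadResult done = asyncResult as CompletedReadResult;
        return done != null ? done.BytesRead : _inner.EndRead(asyncResult);
    }

    // Copies up to 'count' unread prefix bytes into the caller's buffer.
    private int CopyFromPrefix(byte[] buffer, int offset, int count)
    {
        int n = Math.Min(count, _prefix.Length - _prefixPos);
        if (n <= 0) return 0;
        Buffer.BlockCopy(_prefix, _prefixPos, buffer, offset, n);
        _prefixPos += n;
        return n;
    }

    // Minimal IAsyncResult for a read that finished before BeginRead returned.
    private sealed class CompletedReadResult : IAsyncResult
    {
        private readonly object _state;
        private ManualResetEvent _handle;
        public readonly int BytesRead;

        public CompletedReadResult(int bytesRead, object state)
        {
            BytesRead = bytesRead;
            _state = state;
        }

        public object AsyncState { get { return _state; } }
        public WaitHandle AsyncWaitHandle { get { return _handle ?? (_handle = new ManualResetEvent(true)); } }
        public bool CompletedSynchronously { get { return true; } }
        public bool IsCompleted { get { return true; } }
    }

    // Writes go straight through, so SslStream (which needs read/write) can sit on top.
    public override bool CanRead { get { return true; } }
    public override bool CanSeek { get { return false; } }
    public override bool CanWrite { get { return _inner.CanWrite; } }
    public override void Write(byte[] buffer, int offset, int count) { _inner.Write(buffer, offset, count); }
    public override void Flush() { _inner.Flush(); }
    public override long Length { get { throw new NotSupportedException(); } }
    public override long Position
    {
        get { throw new NotSupportedException(); }
        set { throw new NotSupportedException(); }
    }
    public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
    public override void SetLength(long value) { throw new NotSupportedException(); }
}
```

Because EndRead only has to look at the type of the IAsyncResult it receives, there is no need to merge buffers: each read returns either prefix bytes or inner-stream bytes, never a mix, and callers already have to loop because a read may return fewer bytes than requested.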

I spent hours searching for a way to avoid writing a stream wrapper around NetworkStream, and finally came across this, which worked for me: MSDN SocketFlags.Peek. I kept finding suggestions to just write a wrapper or use separate ports, but I have a problem listening to authority or reason.

Here's my code. NLOLOL (No laughing out loud or lectures) I haven't completely figured out if I need to Peek at more than the first byte for all scenarios.

Private Async Sub ProcessTcpClient(__TcpClient As Net.Sockets.TcpClient)

        If __TcpClient Is Nothing OrElse Not __TcpClient.Connected Then Return

        Dim __RequestBuffer(0) As Byte
        Dim __BytesRead As Integer

        Using __NetworkStream As Net.Sockets.NetworkStream = __TcpClient.GetStream

            __BytesRead = __TcpClient.Client.Receive(__RequestBuffer, 0, 1, SocketFlags.Peek)
            If __BytesRead = 1 AndAlso __RequestBuffer(0) = 22 Then
                Await Me.ProcessTcpClientSsl(__NetworkStream)
            Else
                Await Me.ProcessTcpClientNonSsl(__NetworkStream)
            End If

        End Using

        __TcpClient.Close()

End Sub
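The 22 in the VB code is the TLS record content type for a handshake (0x16). Every SSL 3.0/TLS record, including the ClientHello a TLS client sends first, starts with a 5-byte header: content type (1 byte), protocol version (2 bytes, major version 3), and record length (2 bytes). That 5-byte header is very likely also where the question's missing 5 bytes went: SslStream reads a whole record header before it can tell that the data is not TLS. A C# sketch of the same peek check, extended to look at the version byte when it has arrived; TlsSniffer is a hypothetical name, and it does not recognize the old SSLv2-format ClientHello (first byte with the high bit set) that some very old clients send.

```csharp
using System;
using System.Net.Sockets;

// Hypothetical helper: decides whether a new connection starts with a TLS record
// without consuming any bytes, so they are still there for SslStream or for the
// plain-text reader.
public static class TlsSniffer
{
    // TLS record header: [0] content type, [1..2] version (major, minor), [3..4] length.
    public const byte HandshakeContentType = 0x16; // 22, the value the VB answer tests
    public const byte TlsMajorVersion = 0x03;      // SSL 3.0 and every TLS 1.x version

    // 'count' is how many bytes the peek actually returned (can be fewer than requested).
    public static bool LooksLikeTls(byte[] peeked, int count)
    {
        if (count < 1 || peeked[0] != HandshakeContentType)
            return false;
        if (count < 2)
            return true; // only the first byte has arrived; same test as the VB code
        return peeked[1] == TlsMajorVersion;
    }

    // Blocks until at least one byte has arrived; SocketFlags.Peek leaves the data queued.
    public static bool PeekIsTls(Socket socket)
    {
        byte[] header = new byte[5];
        int n = socket.Receive(header, 0, header.Length, SocketFlags.Peek);
        return LooksLikeTls(header, n);
    }
}
```

In the question's WaitForClientData, something like `TlsSniffer.PeekIsTls(clnt.TCPClient.Client)` could decide up front whether to create the SslStream, so neither path loses bytes. Note that the peek is a blocking call, so in the callback-based design it would need to run where blocking briefly is acceptable.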
Big Dummy
  • Thank you for your stubbornness ;-). It helped me greatly to figure out how I could switch my stream from non-SSL to SSL. The wrappers did not help me either. Thanks again! – Stormer Sep 03 '19 at 13:32