
Proxying WebSocket messages between two streams


I have an HTTP proxy server which acts as a middle-man. It basically does the following:

  • Listen for client-browser request
  • Forward the request to the server
  • Parse the server response
  • Forward the response back to client-browser

So basically there is one NetworkStream, or more often an SslStream, between the client-browser and the proxy, and another one between the proxy and the server.

A requirement has arisen to also forward WebSocket traffic between a client and a server.

So now, when a client-browser requests a connection upgrade to WebSocket and the remote server responds with HTTP status code 101 (Switching Protocols), the proxy server keeps both connections open in order to forward further messages from client to server and vice versa.

So after the proxy has received a message from the remote server saying it's ready to switch protocols, it needs to enter a loop where both client and server streams are polled for data, and where any received data is forwarded to the other party.
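As a rough illustration of the decision described above, the sketch below checks whether the server's status line reports 101. The names UpgradeDetector and IsSwitchingProtocols are hypothetical, and the sketch assumes the proxy has already split the response head into lines; a stricter check might also look at the Upgrade and Connection response headers.

```csharp
using System;

// Hypothetical helper: decides whether the server's response status line
// means the proxy should switch to relaying WebSocket traffic.
public static class UpgradeDetector
{
    // Expects a status line such as "HTTP/1.1 101 Switching Protocols".
    public static bool IsSwitchingProtocols(string statusLine)
    {
        if (string.IsNullOrEmpty(statusLine))
        {
            return false;
        }

        // Status line format: HTTP-version SP status-code SP reason-phrase
        string[] parts = statusLine.Split(new[] { ' ' }, 3);
        return parts.Length >= 2
            && parts[0].StartsWith("HTTP/", StringComparison.Ordinal)
            && parts[1] == "101";
    }
}
```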

The problem

WebSocket allows both sides to send messages at any time. This is especially a problem for control messages such as ping/pong: either side may send a ping at any time, and the other side is expected to reply with a pong promptly. Now take two SslStream instances, which, unlike NetworkStream, have no DataAvailable property, so the only way to read data is to call Read/ReadAsync, which may not return until some data is available. Consider the following pseudo-code:

public async Task GetMessage()
{
    // All these methods that we await read from the source stream
    byte[] firstByte = await GetFirstByte(); // 1-byte buffer
    byte[] messageLengthBytes = await GetMessageLengthBytes();
    uint messageLength = GetMessageLength(messageLengthBytes);
    bool isMessageMasked = DetermineIfMessageMasked(messageLengthBytes);
    byte[] maskBytes = null; // only present when the mask bit is set
    if (isMessageMasked)
    {
        maskBytes = await GetMaskBytes();
    }

    byte[] messagePayload = await GetMessagePayload(messageLength);

    // This method writes to the destination stream
    await ComposeAndForwardMessageToOtherParty(firstByte, messageLengthBytes, maskBytes, messagePayload);
}

The above pseudo-code reads from one stream and writes to the other. The problem is that this procedure needs to run for both streams simultaneously, because we don't know which side will send a message at any given point in time. However, it is impossible to perform a write operation while a read operation is active. And because we have no way to poll for incoming data, read operations have to be blocking. That means if we start read operations on both streams at the same time, we can forget about writing to them: one stream will eventually return some data, but we won't be able to send that data to the other stream because it will still be busy trying to read. And that might take a while, at least until the side that owns that stream sends a ping request.
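For reference, the first byte and the first length byte that the pseudo-code reads carry the frame's fixed header fields defined in RFC 6455, section 5.2. The sketch below decodes them; FrameHeader and its members are illustrative names, not part of any library. It also shows how a ping or pong is recognized: control frames are those whose opcode has its most significant bit set, i.e. 0x8 (close), 0x9 (ping) and 0xA (pong).

```csharp
using System;

// Illustrative decoder for the fixed two-byte WebSocket frame header.
public sealed class FrameHeader
{
    public bool Fin { get; }             // final fragment of a message
    public byte Opcode { get; }          // 0x1 text, 0x2 binary, 0x8 close, 0x9 ping, 0xA pong
    public bool Masked { get; }          // a 4-byte mask key follows the length bytes
    public byte LengthIndicator { get; } // 0..125 = length, 126/127 = extended length follows

    private FrameHeader(bool fin, byte opcode, bool masked, byte lengthIndicator)
    {
        Fin = fin;
        Opcode = opcode;
        Masked = masked;
        LengthIndicator = lengthIndicator;
    }

    // Control frames have the most significant opcode bit set (0x8..0xF).
    public bool IsControlFrame => (Opcode & 0x8) != 0;

    // Number of extended payload length bytes that follow: 0, 2 or 8.
    public int ExtendedLengthByteCount =>
        LengthIndicator == 126 ? 2 : LengthIndicator == 127 ? 8 : 0;

    public static FrameHeader Decode(byte first, byte second)
    {
        return new FrameHeader(
            fin: (first & 0x80) != 0,
            opcode: (byte)(first & 0x0F),
            masked: (second & 0x80) != 0,
            lengthIndicator: (byte)(second & 0x7F));
    }
}
```

For example, an empty masked ping from a browser starts with the bytes 0x89 0x80 (FIN set, opcode 0x9, mask bit set, length 0), while an unmasked 300-byte text frame from a server starts with 0x81 0x7E followed by the two extended length bytes.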


Solution

  • Thanks to comments from @MarcGravell, we've learned that independent read and write operations are supported on network streams, i.e. a NetworkStream acts as two independent pipes (one for reading, one for writing); it is fully duplex.

    Therefore, proxying WebSocket messages can be as easy as just starting two independent tasks, one to read from client stream and write to server stream, and another to read from server stream and write to client stream.
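If the proxy does not need to look inside the frames at all, the two independent tasks can be as simple as two stream copies. The following is a minimal sketch assuming raw byte relaying is acceptable; DuplexRelay and RelayAsync are hypothetical names, and the separate in/out parameters exist only so the sketch can be exercised with MemoryStream objects. With real connections, the same SslStream or NetworkStream would be passed as both the in and the out stream for its side.

```csharp
using System.IO;
using System.Threading;
using System.Threading.Tasks;

// Minimal sketch: relay raw bytes in both directions with two independent tasks.
// It does not parse WebSocket frames; each copy runs until its source reports
// end of stream (ReadAsync returns 0) or the token is cancelled.
public static class DuplexRelay
{
    public static Task RelayAsync(
        Stream clientIn, Stream clientOut,
        Stream serverIn, Stream serverOut,
        CancellationToken cancellationToken = default(CancellationToken))
    {
        const int bufferSize = 81920;

        // The two copies run concurrently. On a duplex stream this means at most
        // one pending read (from one task) and one pending write (from the other).
        Task clientToServer = clientIn.CopyToAsync(serverOut, bufferSize, cancellationToken);
        Task serverToClient = serverIn.CopyToAsync(clientOut, bufferSize, cancellationToken);

        return Task.WhenAll(clientToServer, serverToClient);
    }
}
```

Task.WhenAll waits for both directions; a real proxy would probably close the opposite connection once one direction finishes, so that the other copy does not wait indefinitely.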

    In case it helps anyone searching for this, here is how I implemented it:

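    using System;
    using System.Collections.Generic;
    using System.Threading;
    using System.Threading.Tasks;
    
    public class WebSocketRequestHandler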
    {
        private const int MaxMessageLength = 0x7FFFFFFF;
    
        private const byte LengthBitMask = 0x7F;
    
        private const byte MaskBitMask = 0x80;
    
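        // Writes count bytes of buffer to the destination stream (Stream.WriteAsync matches this signature).
        private delegate Task WriteStreamAsyncDelegate(byte[] buffer, int offset, int count, CancellationToken cancellationToken);
    
        // Must return exactly count bytes from the source stream; a single ReadAsync call may return fewer.
        private delegate Task<byte[]> BufferStreamAsyncDelegate(int count, CancellationToken cancellationToken);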
    
        public async Task HandleWebSocketMessagesAsync(CancellationToken cancellationToken = default(CancellationToken))
        {
            var clientListener = ListenForClientMessages(cancellationToken);
            var serverListener = ListenForServerMessages(cancellationToken);
            await Task.WhenAll(clientListener, serverListener);
        }
    
        private async Task ListenForClientMessages(CancellationToken cancellationToken)
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                cancellationToken.ThrowIfCancellationRequested();
                await ListenForMessages(YOUR_CLIENT_STREAM_BUFFER_METHOD_DELEGATE, YOUR_SERVER_STREAM_WRITE_METHOD_DELEGATE, cancellationToken);
            }
        }
    
        private async Task ListenForServerMessages(CancellationToken cancellationToken)
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                cancellationToken.ThrowIfCancellationRequested();
                await ListenForMessages(YOUR_SERVER_STREAM_BUFFER_METHOD_DELEGATE, YOUR_CLIENT_STREAM_WRITE_METHOD_DELEGATE, cancellationToken);
            }
        }
    
        private static async Task ListenForMessages(BufferStreamAsyncDelegate sourceStreamReader,
            WriteStreamAsyncDelegate destinationStreamWriter,
            CancellationToken cancellationToken)
        {
            var messageBuilder = new List<byte>();
            var firstByte = await sourceStreamReader(1, cancellationToken);
            messageBuilder.AddRange(firstByte);
            var lengthBytes = await GetLengthBytes(sourceStreamReader, cancellationToken);
            messageBuilder.AddRange(lengthBytes);
            var isMaskBitSet = (lengthBytes[0] & MaskBitMask) != 0;
            var length = GetMessageLength(lengthBytes);
            if (isMaskBitSet)
            {
                var maskBytes = await sourceStreamReader(4, cancellationToken);
                messageBuilder.AddRange(maskBytes);
            }
    
            var messagePayloadBytes = await sourceStreamReader(length, cancellationToken);
            messageBuilder.AddRange(messagePayloadBytes);
            await destinationStreamWriter(messageBuilder.ToArray(), 0, messageBuilder.Count, cancellationToken);
        }
    
        private static async Task<byte[]> GetLengthBytes(BufferStreamAsyncDelegate sourceStreamReader, CancellationToken cancellationToken)
        {
            var lengthBytes = new List<byte>();
            var firstLengthByte = await sourceStreamReader(1, cancellationToken);
            lengthBytes.AddRange(firstLengthByte);
            var lengthByteValue = firstLengthByte[0] & LengthBitMask;
            if (lengthByteValue <= 125)
            {
                return lengthBytes.ToArray();
            }
    
            switch (lengthByteValue)
            {
                case 126:
                {
                    var secondLengthBytes = await sourceStreamReader(2, cancellationToken);
                    lengthBytes.AddRange(secondLengthBytes);
                    return lengthBytes.ToArray();
                }
                case 127:
                {
                    var secondLengthBytes = await sourceStreamReader(8, cancellationToken);
                    lengthBytes.AddRange(secondLengthBytes);
                    return lengthBytes.ToArray();
                }
                default:
                    throw new Exception($"Unexpected first length byte value: {lengthByteValue}");
            }
        }
    
        private static int GetMessageLength(byte[] lengthBytes)
        {
            byte[] subArray;
            switch (lengthBytes.Length)
            {
                case 1:
                    return lengthBytes[0] & LengthBitMask;
    
                case 3:
                    if (!BitConverter.IsLittleEndian)
                    {
                        return BitConverter.ToUInt16(lengthBytes, 1);
                    }
    
                    subArray = lengthBytes.SubArray(1, 2);
                    Array.Reverse(subArray);
                    return BitConverter.ToUInt16(subArray, 0);
    
                case 9:
                    subArray = lengthBytes.SubArray(1, 8);
                    // The extended length is big-endian (network byte order)
                    if (BitConverter.IsLittleEndian)
                    {
                        Array.Reverse(subArray);
                    }
    
                    var retVal = BitConverter.ToUInt64(subArray, 0);
                    if (retVal > MaxMessageLength)
                    {
                        throw new Exception($"Unexpected payload length: {retVal}");
                    }
    
                    return (int) retVal;
    
                default:
                    throw new Exception($"Unexpected length of lengthBytes array: '{lengthBytes.Length}'");
            }
        }
    }
    

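    It can be used by calling await handler.HandleWebSocketMessagesAsync(cancellationToken) once the initial handshake has completed (i.e. after the server's 101 response has been forwarded to the client) and the YOUR_*_DELEGATE placeholders have been replaced with the reader and writer delegates for the client and server streams.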
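The buffer delegates above must return exactly the requested number of bytes, whereas a single Stream.ReadAsync call may return fewer. A minimal sketch of such a reader follows; StreamDelegates and ReadExactlyAsync are hypothetical names, and throwing EndOfStreamException when the peer disconnects is a design choice of this sketch, not part of the original answer. Newer runtimes (.NET 7 and later) also provide a built-in Stream.ReadExactlyAsync with similar semantics.

```csharp
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper that satisfies the BufferStreamAsyncDelegate contract.
public static class StreamDelegates
{
    // Loops over ReadAsync until exactly count bytes have been read.
    public static async Task<byte[]> ReadExactlyAsync(
        Stream stream, int count, CancellationToken cancellationToken)
    {
        var buffer = new byte[count];
        int offset = 0;
        while (offset < count)
        {
            int read = await stream.ReadAsync(buffer, offset, count - offset, cancellationToken)
                .ConfigureAwait(false);
            if (read == 0)
            {
                // The peer closed the connection before the requested bytes arrived.
                throw new EndOfStreamException($"Stream ended after {offset} of {count} bytes.");
            }
            offset += read;
        }
        return buffer;
    }
}
```

Assuming clientStream and serverStream hold the two connections, the client-to-server placeholders could be filled with (count, ct) => StreamDelegates.ReadExactlyAsync(clientStream, count, ct) and serverStream.WriteAsync, since Stream.WriteAsync(byte[], int, int, CancellationToken) already matches WriteStreamAsyncDelegate. A zero-length request, such as an empty ping payload, returns an empty array without touching the stream.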

    The SubArray extension method is taken from here: https://stackoverflow.com/a/943650/828023 (also from @Marc).
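On runtimes that ship System.Buffers.Binary.BinaryPrimitives (.NET Core 2.1 and later, or the System.Memory package on .NET Framework), the big-endian length can be read without SubArray and Array.Reverse, independently of the machine's byte order. A sketch under that assumption; FrameLength.Decode is a hypothetical drop-in for GetMessageLength, taking the same input layout and enforcing the same int.MaxValue limit.

```csharp
using System;
using System.Buffers.Binary;
using System.IO;

// Hypothetical alternative to GetMessageLength. lengthBytes is the second frame
// byte, optionally followed by 2 or 8 big-endian extended length bytes.
public static class FrameLength
{
    public static int Decode(byte[] lengthBytes)
    {
        switch (lengthBytes.Length)
        {
            case 1:
                // Strip the mask bit; the remaining 7 bits are the length (0..125).
                return lengthBytes[0] & 0x7F;
            case 3:
                return BinaryPrimitives.ReadUInt16BigEndian(lengthBytes.AsSpan(1, 2));
            case 9:
                ulong value = BinaryPrimitives.ReadUInt64BigEndian(lengthBytes.AsSpan(1, 8));
                // Same limit as MaxMessageLength, since the payload is buffered in a byte[].
                if (value > int.MaxValue)
                {
                    throw new InvalidDataException($"Unsupported payload length: {value}");
                }
                return (int)value;
            default:
                throw new ArgumentException(
                    $"Unexpected lengthBytes length: {lengthBytes.Length}", nameof(lengthBytes));
        }
    }
}
```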