Tags: c#, sockets, asynchronous, .net-standard-2.0, networkstream

How to read continuously and asynchronously from network stream using abstract stream?


Description

I would like to read asynchronously from a NetworkStream or SslStream through their abstract Stream base class. There are several ways to read asynchronously from a stream:

  • Asynchronous Programming Model (APM): It uses BeginRead and EndRead operations.
  • Task Parallel Library (TPL): It uses Task and creates task continuations.
  • Task-based Asynchronous Pattern (TAP): Operations are suffixed with Async, and the async and await keywords can be used.
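
For comparison, here is a minimal sketch of one read expressed in each of the three styles. The class and method names (ReadPatterns, ReadWithApm, ReadWithContinuation, ReadWithTap) are made up for illustration; only BeginRead, EndRead, ReadAsync, FromAsync and ContinueWith are framework calls.

```csharp
using System.IO;
using System.Threading.Tasks;

public static class ReadPatterns
{
    // APM: the BeginRead/EndRead pair, wrapped in a Task by FromAsync so the caller can wait on it.
    public static Task<int> ReadWithApm(Stream stream, byte[] buffer)
    {
        return Task<int>.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, 0, buffer.Length, null);
    }

    // TPL: start the read and attach an explicit continuation to the resulting task.
    public static Task<int> ReadWithContinuation(Stream stream, byte[] buffer)
    {
        // t.Result does not block here, because the continuation runs after the read has finished.
        return stream.ReadAsync(buffer, 0, buffer.Length).ContinueWith(t => t.Result);
    }

    // TAP: the same read, written with async/await.
    public static async Task<int> ReadWithTap(Stream stream, byte[] buffer)
    {
        return await stream.ReadAsync(buffer, 0, buffer.Length);
    }
}
```

All three return the number of bytes that were read, so against the same stream contents they should be interchangeable for the caller.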

I am mostly interested in using TAP for the asynchronous read. The code below asynchronously reads to the end of the stream and returns the data as a byte array:

    internal async Task<byte[]> ReadToEndAsync(Stream stream)
    {
        byte[] buffer = new byte[4096];
        using (MemoryStream memoryStream = new MemoryStream())
        {
            int bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length);
            while (bytesRead != 0)
            {
                // Aggregate the received data in a memory stream.
                await memoryStream.WriteAsync(buffer, 0, bytesRead);
                bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length);
            }

            return memoryStream.ToArray();
        }
    }

The buffer size is 4096 bytes (4 KB). If the transferred data is larger than the buffer, the loop keeps reading until ReadAsync returns 0 (zero), which signals the end of the stream. This works with FileStream, but with NetworkStream or SslStream it hangs indefinitely at the ReadAsync call. These two streams behave differently from other streams.

The problem is that a network stream's ReadAsync returns 0 (zero) only when the Socket connection is closed. However, I do not want to close the connection every time data is transferred over the network.

Question

How can I keep ReadAsync from waiting indefinitely without closing the Socket connection?


Solution

  • I have an issue similar to yours: I want to read from a network stream that might send packets at random intervals. I am receiving length-prefixed packets, that is, each packet begins with a byte (e.g., 20) that says how many more bytes I need to read from the network stream to get the rest of the packet.

    In my testing, the earlier solutions I tried had many issues: truncated payloads, partially missing payloads, very high CPU usage, no payload at all, exceptions being thrown, clients disconnecting immediately at random, or what appeared to be a deadlock (or just waiting forever).
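
    To make the framing concrete, here is a minimal sketch of a one-byte length prefix, built and parsed entirely in memory. LengthPrefixedFrame, Encode and DecodeAll are hypothetical names used only for this illustration, and the sketch assumes UTF-8 text payloads of at most 255 bytes.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

public static class LengthPrefixedFrame
{
    // Builds one frame: a single length byte followed by the UTF-8 payload.
    public static byte[] Encode(string message)
    {
        byte[] payload = Encoding.UTF8.GetBytes(message);
        if (payload.Length > byte.MaxValue)
            throw new ArgumentException("Payload does not fit a one-byte length prefix.", nameof(message));

        byte[] frame = new byte[1 + payload.Length];
        frame[0] = (byte)payload.Length;
        Buffer.BlockCopy(payload, 0, frame, 1, payload.Length);
        return frame;
    }

    // Splits a byte array holding zero or more complete frames back into messages.
    public static List<string> DecodeAll(byte[] data)
    {
        var messages = new List<string>();
        int offset = 0;
        while (offset < data.Length)
        {
            int length = data[offset];
            if (offset + 1 + length > data.Length)
                throw new InvalidDataException("Incomplete frame at the end of the data.");

            messages.Add(Encoding.UTF8.GetString(data, offset + 1, length));
            offset += 1 + length;
        }
        return messages;
    }
}
```

    For example, Encode("Hello #0!") produces a 10-byte frame whose first byte is 9. The receiver never needs the connection to close: the prefix alone tells it where each message ends.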

    Starting with .NET 7, there are two very interesting and useful methods:

    • ReadAtLeastAsync
    • ReadExactlyAsync

    In my limited testing, they have been very reliable and I haven't had any issues, even when data is sent at random intervals. The client/server can send data with delays, and the read simply awaits until the data arrives. I have attached an example below showing how these can be used, including a simulated flaky network, which these methods handle without problems.

    There is a lot of machinery for these methods if you look at their source.

    ReadAtLeastAsync

    Docs: https://learn.microsoft.com/en-us/dotnet/api/system.io.stream.readatleastasync?view=net-8.0

    n is the minimumBytes argument that you supply: the minimum number of bytes to read. Don't set it to zero, because the call then completes immediately without waiting for data, so a read loop will just spin.

    • This is guaranteed to read at least n bytes into your buffer if throwOnEndOfStream is true (the default), because you'll get an EndOfStreamException if it can't. It waits until the stream has delivered at least n bytes, and then returns the number of bytes read (which may be more than n, up to the size of the buffer).
    • This is not guaranteed to read n bytes into your buffer if throwOnEndOfStream is false. Instead, you have to check the return value. If it is zero, then the client closed the connection. If it is not zero, then use only that many bytes from the start of the buffer, and keep calling ReadAtLeastAsync until it returns zero.
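
    The return-value check described in the last bullet can be sketched against any Stream. ReadAtLeastDemo and ReadChunksAsync are hypothetical names; ReadAtLeastAsync itself requires .NET 7 or later.

```csharp
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;

public static class ReadAtLeastDemo
{
    // Reads the stream in chunks of at least minimumBytes until the stream ends.
    public static async Task<List<byte[]>> ReadChunksAsync(Stream stream, int bufferSize, int minimumBytes)
    {
        var chunks = new List<byte[]>();
        var buffer = new byte[bufferSize];

        while (true)
        {
            int read = await stream.ReadAtLeastAsync(buffer, minimumBytes, throwOnEndOfStream: false);
            if (read == 0)
                break; // end of stream, nothing left to read

            // Only the first 'read' bytes of the buffer are valid.
            chunks.Add(buffer[..read]);

            if (read < minimumBytes)
                break; // the stream ended before minimumBytes could be delivered
        }

        return chunks;
    }
}
```

    With a 10-byte MemoryStream, a 4-byte buffer and minimumBytes set to 4, this returns chunks of 4, 4 and 2 bytes. On a NetworkStream the same loop waits between chunks instead of returning early.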

    ReadExactlyAsync

    Docs: https://learn.microsoft.com/en-us/dotnet/api/system.io.stream.readexactlyasync?view=net-8.0

    This reads exactly n bytes into your buffer, where n is the length of the buffer you pass. It waits on the stream until all n bytes have been received and then continues. If the stream ends first, it throws an EndOfStreamException.
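
    Since the question is tagged .net-standard-2.0, where ReadExactlyAsync is not available, the same behaviour can be approximated with a loop over ReadAsync. This is only a sketch of the idea, not the framework's implementation; ExactReader and FillBufferAsync are hypothetical names.

```csharp
using System.IO;
using System.Threading;
using System.Threading.Tasks;

public static class ExactReader
{
    // Keeps reading until the whole buffer is filled, or throws if the stream ends first.
    public static async Task FillBufferAsync(Stream stream, byte[] buffer, CancellationToken cancellationToken = default)
    {
        int totalRead = 0;
        while (totalRead < buffer.Length)
        {
            // ReadAsync may return fewer bytes than requested, so ask only for the remainder.
            int read = await stream.ReadAsync(buffer, totalRead, buffer.Length - totalRead, cancellationToken)
                                   .ConfigureAwait(false);
            if (read == 0)
                throw new EndOfStreamException("The stream ended before the buffer was filled.");

            totalRead += read;
        }
    }
}
```

    A length-prefixed message is then read with two calls: one with a 1-byte buffer for the prefix, and one with a buffer of the announced length for the payload.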

    Paste this entire code block into a new project (e.g., Program.cs) and then run it (look at Main() for the different examples):

    using System;
    using System.IO;
    using System.Linq;
    using System.Net;
    using System.Net.Sockets;
    using System.Text;
    using System.Threading.Tasks;
    
    namespace ConsoleApp1
    {
        public static class Logger
        {
            public static void LogWithTime(string message)
            {
                Console.WriteLine($"[{DateTime.Now}]: {message}");
            }
        }
    
        /// <summary>
        /// This class is not super important, because it will be on the server's side.
        /// This simulates the server sending packets whenever they want to, and might
        /// delay many seconds between sending each packet at random times.
        ///
        /// The server sends messages as message length prefixed payloads. The payload contains
        /// the message, in UTF-8, as bytes.
        /// Messages have a maximum size of 255 bytes in this example.
        /// </summary>
        public static class TcpServerExample1
        {
            public static async Task SendMessages()
            {
                Logger.LogWithTime("Server: starting server...");
                var client = new TcpClient();
                await client.ConnectAsync(IPAddress.Loopback, 8000);
                Logger.LogWithTime("Server: connected to client");
    
                await using (var networkStream = client.GetStream())
                await using (var writer = new BinaryWriter(networkStream))
                {
                    for (var i = 0; i < 10; i++)
                    {
                        // send a message to client
                        var message = $"Hello #{i}!";
                        var messageAsBytes = Encoding.UTF8.GetBytes(message);
                        // for i = 0: messageAsBytes = [72, 101, 108, 108, 111, 32, 35, 48, 33]
    
                        var messageLengthPrefix = new[] { (byte)messageAsBytes.Length };
                        // for i = 0: messageLengthPrefix = [9]
    
                        await Task.Delay(TimeSpan.FromSeconds(5));
                        writer.Write(messageLengthPrefix);
                        Logger.LogWithTime($"Server: wrote message prefix for message #{i}");
    
                        await Task.Delay(TimeSpan.FromSeconds(5));
                        writer.Write(messageAsBytes);
                        Logger.LogWithTime($"Server: wrote message payload for message #{i}");
                    }
                }
    
                client.Close();
            }
        }
    
        /// <summary>
        /// This is us. We listen to the server, who sends us packets.
        /// </summary>
        public static class TcpClientExample1
        {
            public static async Task StartListening()
            {
                var listener = new TcpListener(IPAddress.Loopback, 8000);
                listener.Start();
    
                var client = await listener.AcceptTcpClientAsync();
                await using var networkStream = client.GetStream();
    
                try
                {
                    while (true)
                    {
                        // again, assume that maximum message length is 255 bytes
                        var msgLengthBuffer = new byte[1];
                        await networkStream.ReadExactlyAsync(msgLengthBuffer);
    
                        var msgPayloadBuffer = new byte[msgLengthBuffer[0]];
                        await networkStream.ReadExactlyAsync(msgPayloadBuffer);
    
                        Logger.LogWithTime($"Client: received message '{Encoding.UTF8.GetString(msgPayloadBuffer)}'");
                    }
                }
                catch (EndOfStreamException)
                {
                    // ReadExactlyAsync throws this once the server has closed the connection
                    Logger.LogWithTime("Client: server closed the connection");
                }
                finally
                {
                    // remember to close the connection and stop listening
                    client.Close();
                    listener.Stop();
                }
            }
        }
    
        /// <summary>
        /// Here, the server actively connects and then passively reads messages,
        /// while the client listens for the connection and actively sends them
        /// </summary>
        public static class TcpServerExample2
        {
            public static async Task StartListening()
            {
                Logger.LogWithTime("Server: starting server...");
                var client = new TcpClient();
                await client.ConnectAsync(IPAddress.Loopback, 8000);
                Logger.LogWithTime("Server: connected to client");
    
                await using (var networkStream = client.GetStream())
                {
                    try
                    {
                        // keep reading length-prefixed messages until the client closes the connection
                        while (true)
                        {
                            var messageLength = new byte[1];
                            await networkStream.ReadExactlyAsync(messageLength);
    
                            var messagePayload = new byte[messageLength[0]];
                            await networkStream.ReadExactlyAsync(messagePayload);
    
                            Logger.LogWithTime($"Server: received '{Encoding.UTF8.GetString(messagePayload)}'");
                        }
                    }
                    catch (EndOfStreamException)
                    {
                        Logger.LogWithTime("Server: client closed the connection");
                    }
                }
    
                client.Close();
            }
        }
    
        public static class TcpClientExample2
        {
            public static async Task SendMessages()
            {
                Logger.LogWithTime("Client: starting client...");
                // one side has to listen for the other to connect; here it is the client
                var listener = new TcpListener(IPAddress.Loopback, 8000);
                listener.Start();
    
                var client = await listener.AcceptTcpClientAsync();
                Logger.LogWithTime("Client: connected to server");
    
                await using (var networkStream = client.GetStream())
                await using (var writer = new BinaryWriter(networkStream))
                {
                    for (var i = 0; i < 10; i++)
                    {
                        // send a message to the server
                        var message = $"Hello #{i}!";
                        var messageAsBytes = Encoding.UTF8.GetBytes(message);
                        // for i = 0: messageAsBytes = [72, 101, 108, 108, 111, 32, 35, 48, 33]
    
                        var messageLengthPrefix = new[] { (byte)messageAsBytes.Length };
                        // for i = 0: messageLengthPrefix = [9]
    
                        await Task.Delay(TimeSpan.FromSeconds(5));
                        writer.Write(messageLengthPrefix);
                        Logger.LogWithTime($"Client: wrote message prefix for message #{i}");
    
                        await Task.Delay(TimeSpan.FromSeconds(5));
                        writer.Write(messageAsBytes);
                        Logger.LogWithTime($"Client: wrote message payload for message #{i}");
                    }
                }
    
                client.Close();
                listener.Stop();
            }
        }
    
        /// <summary>
        /// Here, the server passively reads whatever bytes arrive using ReadAtLeastAsync
        /// (there is no length prefix), while the client actively sends random payloads
        /// </summary>
        public static class TcpServerExample3
        {
            public static async Task StartListening()
            {
                Logger.LogWithTime("Server: starting server...");
                var client = new TcpClient();
                await client.ConnectAsync(IPAddress.Loopback, 8000);
                Logger.LogWithTime("Server: connected to client");
    
                await using (var networkStream = client.GetStream())
                {
                    while (true)
                    {
                        var buffer = new byte[8]; // doesn't have to be eight
                        var minBytesToRead = 1;
                        var numReadBytes = await networkStream.ReadAtLeastAsync(buffer, minBytesToRead, false);
                        if (numReadBytes == 0)
                        {
                            // this occurs when the client closes the connection
                            Logger.LogWithTime("Server: client closed connection");
                            break;
                        }
    
                        // the payload from the client may be smaller than the buffer we allocated for it,
                        // depending on how the packets arrive and their frequency,
                        // so make sure to only take the first {numReadBytes} bytes from the buffer
                        Logger.LogWithTime($"Server: received [{string.Join(", ", buffer[..numReadBytes])}]");
                    }
                }
    
                client.Close();
            }
        }
    
        public static class TcpClientExample3
        {
            public static async Task SendMessages()
            {
                Logger.LogWithTime("Client: starting client...");
                var listener = new TcpListener(IPAddress.Loopback, 8000);
                listener.Start();
    
                var client = await listener.AcceptTcpClientAsync();
                await using var networkStream = client.GetStream();
    
                Logger.LogWithTime("Client: connected to server");
                await using (var writer = new BinaryWriter(networkStream))
                {
                    for (var i = 0; i < 20; i++)
                    {
                        // random payloads of length 1 to 19 (the upper bound of Next is exclusive)
                        var randomPayload = Enumerable.Repeat(0, Random.Shared.Next(1, 20)).Select(_ => (byte)Random.Shared.Next(0, byte.MaxValue)).ToArray();
    
                        writer.Write(randomPayload);
    
                        Logger.LogWithTime($"Client: sent [{string.Join(", ", randomPayload)}]");
    
                        // simulate slow network
                        await Task.Delay(TimeSpan.FromMilliseconds(Random.Shared.Next(0, 5000)));
                    }
                }
    
                client.Close();
            }
        }
    
        internal class Program
        {
            static async Task Main(string[] args)
            {
                // passively connecting (listening), actively sending
                //await Example1();
    
                // actively connecting, passively listening
                //await Example2();
    
                // download something from the server that might arrive in
                // packets delayed by random amounts of time
                await Example3();
            }
    
            private static async Task Example1()
            {
                // client is us
                var client = TcpClientExample1.StartListening();
                var server = TcpServerExample1.SendMessages();
    
                await Task.WhenAll(client, server);
            }
    
            private static async Task Example2()
            {
                var client = TcpClientExample2.SendMessages();
                var server = TcpServerExample2.StartListening();
    
                await Task.WhenAll(client, server);
            }
    
            private static async Task Example3()
            {
                // start the listening side first so the connecting side has something to connect to
                var client = TcpClientExample3.SendMessages();
                var server = TcpServerExample3.StartListening();
    
                await Task.WhenAll(client, server);
            }
        }
    }