Search code examples
Tags: python, tcp, python-asyncio

How to implement logic like sock.recv(), based on python-asyncio's transport & protocol api?


I'm trying to build some simple applications based on asyncio TCP. In traditional socket programming, we use sock.recv() and sock.send() to manage receiving and sending on sockets. However, I noticed that the asyncio documentation does not recommend using sockets directly; instead, it suggests using the transport abstraction.

I want to know how to use a transport to reproduce logic similar to traditional socket programming. For example, I'd like to implement the following logic:

async def main():
    """Desired (non-working) usage from the question: explicit write/read calls.

    NOTE(review): this is illustrative pseudo-code, not runnable as written.
    Transports expose no read() method (that is the point of the question),
    and the trailing "...." only stands for "more code here".
    """
    loop = asyncio.get_running_loop()
    # NOTE(review): create_connection() expects a protocol *factory* (e.g. the
    # class itself), not an instance; transport.write() is also not a
    # coroutine, so it should not be awaited — confirm against asyncio docs.
    transport, protocal = await loop.create_connection(EchoClientProtocol(), '', 25000)
    await transport.write("hello")
    await transport.read(5) # Error: no read() on transports; received data is delivered to the protocol
    ....

The above code does not work because a transport does not provide a read method in the first place; the read event must be handled in the corresponding protocol. This prevents me from clearly separating different TCP messages. What is the right way to do this? Thanks.


Solution

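  • You can implement the TCP server and client using asyncio streams (asyncio.start_server / asyncio.open_connection). Their StreamReader.read() and StreamWriter.write() play the role of sock.recv() and sock.send(), while the transport/protocol plumbing stays hidden.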

    Edit based on @user4815162342's great suggestions:

    I increased the maximum number of bytes read per chunk from 1 byte to 8192 bytes. Using the smallest possible number in the example was a bad idea on my part, and it could mislead other people.

    In addition, BytesIO is much better suited for repeated concatenation than using += on bytes, so I introduced BytesIO into this code example.

    Server Script Example:

    import asyncio
    import socket
    from io import BytesIO
    
    
    async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client: collect its whole request, then echo it back.

        The request ends when the client half-closes its side (EOF). Each
        individual read must finish within 2 seconds; a client slower than
        that is disconnected without a reply. The connection is always
        closed on exit, and task cancellation (e.g. server shutdown) is
        propagated instead of being swallowed.
        """
        print(len(asyncio.all_tasks()))  # let's show number of tasks
        ip, port = writer.get_extra_info('peername')  # get info about incoming connection
        print(f"Incoming connection from {ip}: {port}")
        # BytesIO avoids re-copying the whole buffer on every `bytes += chunk`
        all_data = BytesIO()
        try:
            while True:
                try:
                    # read chunk up to 8 kbytes
                    data = await asyncio.wait_for(reader.read(8192), timeout=2.0)
                except asyncio.TimeoutError:
                    print("Too slow connection aborted")
                    return  # abort: send no reply; the finally clause closes the socket
                all_data.write(data)
                if reader.at_eof():
                    break

            payload = all_data.getvalue()
            # errors="replace": the client may send bytes that are not valid UTF-8
            print(f"Received data:\n{payload.decode('utf8', errors='replace')}")

            writer.write(b"FROM_SERVER:\n")  # prepare data
            writer.write(payload)  # prepare more data
            # simulate slow server
            # await asyncio.sleep(5)
            await writer.drain()  # send all prepared data

            if writer.can_write_eof():
                writer.write_eof()
        finally:
            writer.close()  # do not forget to close stream
            try:
                await writer.wait_closed()
            except ConnectionError:
                pass  # the peer already dropped the connection; nothing left to clean up
    
    
    async def main_server():
        """Start the echo server on localhost:8888 (IPv4 only) and run it until cancelled."""
        srv = await asyncio.start_server(
            client_connected_cb=handler, host="localhost", port=8888,
            family=socket.AF_INET,  # ipv4
        )

        # report the address the listening socket actually bound to
        bound_ip, bound_port = srv.sockets[0].getsockname()
        print(f"Serving on: {bound_ip}:{bound_port}")
        print("*" * 200)

        async with srv:  # closes the server when serving stops
            await srv.serve_forever()
    
    if __name__ == '__main__':
        # script entry point: build the server coroutine and drive it to completion
        server_coro = main_server()
        asyncio.run(server_coro)
    

    Client Script Example:

    import asyncio
    from io import BytesIO
    
    
    async def main():
        """Send ten lines to the echo server, half-close, and print its reply.

        Each read of the reply must arrive within 1 second. Connection errors
        raised while sending OR receiving (e.g. the server dropped a client it
        considered too slow) are reported instead of escaping. Cancellation is
        propagated, and the connection is always closed on exit.
        """
        reader, writer = await asyncio.open_connection(host="localhost", port=8888)
        # BytesIO avoids re-copying the whole buffer on every `bytes += chunk`
        data_from_server = BytesIO()
        try:
            # remove comment to test slow client
            # await asyncio.sleep(20)
            for i in range(10):
                writer.write(f"hello-{i}\n".encode("utf8"))  # prepare data
                await writer.drain()  # send data

            if writer.can_write_eof():
                writer.write_eof()  # tell server that we sent all data

            while True:
                # read chunk up to 8 kbytes
                data = await asyncio.wait_for(reader.read(8192), timeout=1.0)
                data_from_server.write(data)
                # if server told use that no more data
                if reader.at_eof():
                    break

            # errors="replace": don't crash on a reply that is not valid UTF-8
            print(data_from_server.getvalue().decode('utf8', errors='replace'))
        except ConnectionError:
            # aborted/reset/broken pipe: e.g. our client was too slow and the server dropped us
            print("Server timed out connection")
        except asyncio.TimeoutError:
            # if server was too slow
            print("Did not get answer from server due to timeout")
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                pass  # already reported above; the socket is gone either way
    
    if __name__ == '__main__':
        # script entry point: build the client coroutine and drive it to completion
        client_coro = main()
        asyncio.run(client_coro)