Search code examples
python-3.xsocketspython-asyncio

How to use multiple socket services in an asynchronous context manager?


I've created a ConnectedSocket context manager for connecting and closing socket connections. However, when using create_task within the context manager to establish connections, the established connection immediately triggers aexit, and normal socket communication cannot proceed.

import asyncio
import socket
from asyncio import AbstractEventLoop
from types import TracebackType
from typing import Optional, Type
from typing import List


class ConnectedSocket:
    def __init__(self, server_socket):
        self._connection = None
        self._server_socket = server_socket

    async def __aenter__(self):
        print("Entering context manager, waiting for connection")
        loop = asyncio.get_event_loop()
        conn, addr = await loop.sock_accept(self._server_socket)
        conn.setblocking(False)  #
        self._connection = conn
        print("Accepted a connection")
        return self._connection

    async def __aexit__(self,
                        exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]):
        print("Exiting context manager")
        self._connection.close()
        print("Closed connection")


class AsyncSocketServer:
    def __init__(self):
        self.loop: AbstractEventLoop = asyncio.get_event_loop()
        self.tasks: List[asyncio.Task] = []

    async def create_server(self):
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  #
        server_address = ('127.0.0.1', 8000)
        server_socket.bind(server_address)
        server_socket.setblocking(False)
        server_socket.listen()
        await self.connection_listener(server_socket)

    async def connection_listener(self, server_socket):
        while True:
            async with ConnectedSocket(server_socket) as conn:
                task: asyncio.Task = asyncio.create_task(self.echo(conn))
                self.tasks.append(task)

                #Block
                # while data := await self.loop.sock_recv(conn, 1024):
                #     if data == b'exit\r\n':
                #         break
                #     print(data)

    async def echo(self, conn: socket) -> None:
        while data := await self.loop.sock_recv(conn, 1024):
            if data == b'exit\r\n':
                break
            print(data)

    def run(self):
        self.loop.run_until_complete(self.create_server())


if __name__ == "__main__":
    server = AsyncSocketServer()
    server.run()

telnet localhost 8000

Trying ::1... 
telnet: connect to address ::1: Connection refused 
Trying 127.0.0.1... 
Connected to localhost. Escape
character is '^]'. 
Connection closed by foreign host.

Solution

  • I used "self._connection.close()" in aexit which caused the connection to close every time the ConnectedSocket popped out, so I removed it and it worked.