Search code examples
Tags: websocket, pytest, fastapi

The websocket freezes in pytest on receive_text()


I have an old version of the code in which the test passes. After I updated the code, the test started hanging on `receive_text()`. I have no idea why, except that `WebSocketDisconnect` is triggered in the old code but not in the updated version.

config.py

# Test database engine; NullPool disables connection pooling.
engine_test = create_async_engine(TEST_POSTGRES_URI, poolclass=NullPool)
# Session factory bound to the test engine; objects are not expired on commit.
async_session_maker = async_sessionmaker(
    engine_test,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def override_session():
    """Yield an AsyncSession bound to the test engine (override for ``get_session``)."""
    session_context = async_session_maker()
    async with session_context as session:
        yield session


topus.dependency_overrides[get_session] = override_session
async def defaults_user():
    """Insert the two default test users ("TestUserDB" and "TestUserDB2").

    Both users get the same hashed password. Any failure is printed rather
    than raised, so the caller continues even if the users were not created.
    """
    usernames = ("TestUserDB", "TestUserDB2")
    try:
        async with async_session_maker() as session:
            session.add_all([
                UserDB(username=name, password=hash_password("12345678"))
                for name in usernames
            ])
            await session.commit()
    except Exception as e:
        print(f"An error occurred while adding default user: {e}")

@pytest.fixture(autouse=True, scope='session')
async def lifespan():
    """Session-wide setup and teardown of the test database.

    Creates all tables and inserts the default users before the tests run,
    then drops all tables after the session's tests have finished.
    """
    metadata = Base.metadata

    async with engine_test.begin() as connection:
        await connection.run_sync(metadata.create_all)
    await defaults_user()

    yield

    async with engine_test.begin() as connection:
        await connection.run_sync(metadata.drop_all)

old_chat.py

class PrivateManager:
    """Manager for private-chat websockets.

    Keeps, per chat id, the websockets registered for that chat and sends
    every broadcast message to all of them, including the sender's own
    socket (so the sender receives an echo).
    """

    def __init__(self):
        # chat_id -> websockets currently registered for that chat.
        self.connections: Dict[str, List[WebSocket]] = {}

    def register_websocket(self, chat_id: str, websocket: WebSocket) -> None:
        """Append ``websocket`` to the list of sockets registered for ``chat_id``."""
        self.connections.setdefault(chat_id, []).append(websocket)

    async def broadcast(self, chat_id: str, message: str, sender_id: int, friend_id: int, add_to_db: bool, db: AsyncSession = None) -> None:
        """Optionally persist ``message``, then send it to every socket of ``chat_id``.

        Does nothing (no save, no send) when ``chat_id`` is not in ``connections``.
        """
        if chat_id not in self.connections:
            return

        if add_to_db:
            # NOTE(review): save_message_to_db is not shown in this version of
            # the class; presumably it is defined as in the updated manager.
            await self.save_message_to_db(chat_id=chat_id, message=message, sender=sender_id, friend_id=friend_id, db=db)

        # Iterate over a snapshot: send_text() suspends, and another handler may
        # call disconnect() on this same list in the meantime, which would make
        # iteration over the live list skip sockets.
        for websocket in list(self.connections[chat_id]):
            await websocket.send_text(message)

    def disconnect(self, chat_id, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Unknown chat ids and sockets that are not registered are ignored. A
        chat whose last socket is removed is dropped from ``connections``, so
        the dict does not keep an empty entry for every chat ever opened.
        """
        chat = self.connections.get(chat_id)
        if not chat:
            return
        if websocket in chat:
            chat.remove(websocket)
        if not chat:
            del self.connections[chat_id]

@chat.websocket('/private_chat/{friend_id}')
async def private_chat(friend_id: int, websocket: WebSocket, token: dict = Depends(decode_token), db: AsyncSession = Depends(get_session)):
    """Serve the current user's side of a private chat with ``friend_id``.

    Every text message received on this socket is saved and sent to all
    sockets registered for the chat, including this one (echo).

    ``token`` is the decoded auth token; its ``'id'`` entry is used as the
    sender's user id.
    """
    if token['id'] == friend_id:
        # A websocket endpoint cannot answer with an HTTP JSONResponse: its
        # return value is ignored and the handshake is never completed.
        # Closing before accept() rejects the handshake instead.
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason='Собі не можна написати')
        return
    chat_id = chat_id_generator(token['id'], friend_id)

    # Accept first, so the manager never holds (and broadcasts to) a socket
    # whose handshake has not completed yet.
    await websocket.accept()
    private_manager.register_websocket(chat_id, websocket)

    try:
        while True:
            data = await websocket.receive_text()
            await private_manager.broadcast(chat_id=chat_id, message=data, sender_id=token['id'], friend_id=friend_id, db=db, add_to_db=True)
    except WebSocketDisconnect:
        logging.error(msg='da', exc_info=True)
    finally:
        # Unregister on every exit path (including DB or send errors), not
        # only on a clean client disconnect.
        private_manager.disconnect(chat_id, websocket)

pytest

        """
        Юзер надсилає приватне повідомлення
        """
        client = TestClient(topus)
        login = {
            "username": "TestUserDB",
            "password": "12345678",
        }
        response = client.post('/auth/login', json=login)
        assert response.status_code == 200

        with client.websocket_connect('/chat/private_chat/2') as websocket:
            websocket.send_text('Hi, Private!')
            ans = websocket.receive_text()
            assert ans == 'Hi, Private!'
            websocket.close()

new_chat.py

Freezes on receive_text() in pytest:

class PrivateManager:
    """Manager for private-chat websockets.

    Keeps, per chat id, the websockets registered for that chat, optionally
    persists incoming messages, and relays them to the chat's *other* sockets.
    """

    def __init__(self):
        # chat_id -> websockets currently registered for that chat.
        self.connections: Dict[str, List[WebSocket]] = {}

    async def register_websocket(self, chat_id: str, user_id: int, friend_id: int, websocket: WebSocket, db: AsyncSession) -> None:
        """Call ``get_or_create_chat`` for this chat, then register ``websocket``.

        The database call runs first. If it raises, no socket is left
        registered for a chat that could not be set up.
        """
        await get_or_create_chat(
            chat_id=chat_id, user_id=user_id, friend_id=friend_id, db=db)
        self.connections.setdefault(chat_id, []).append(websocket)

    async def broadcast(self, chat_id: str, message: str, sender_id: int, friend_id: int, add_to_db: bool, userwebsocket: WebSocket, db: AsyncSession = None) -> None:
        """Optionally persist ``message``, then send it to the chat's other sockets.

        ``userwebsocket`` (the sender's socket) is skipped, so the sender does
        not get its own message back. A client that is alone in the chat
        receives nothing, and a ``receive_text()`` waiting for an echo blocks
        forever. Does nothing when ``chat_id`` is not in ``connections``.
        """
        if chat_id not in self.connections:
            return

        if add_to_db:
            await self.save_message_to_db(chat_id=chat_id, message=message, sender=sender_id, friend_id=friend_id, db=db)

        # Iterate over a snapshot: send_text() suspends, and another handler may
        # call disconnect() on this same list in the meantime, which would make
        # iteration over the live list skip sockets.
        for websocket in list(self.connections[chat_id]):
            if websocket is not userwebsocket:
                await websocket.send_text(message)

    def disconnect(self, chat_id, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Unknown chat ids and sockets that are not registered are ignored. A
        chat whose last socket is removed is dropped from ``connections``, so
        the dict does not keep an empty entry for every chat ever opened.
        """
        chat = self.connections.get(chat_id)
        if not chat:
            return
        if websocket in chat:
            chat.remove(websocket)
        if not chat:
            del self.connections[chat_id]

    @staticmethod
    async def save_message_to_db(chat_id: str, message: str, sender: int, friend_id, db: AsyncSession) -> None:
        """Persist one chat message by delegating to ``save_message``."""
        await save_message(chat_id=chat_id, message=message, sender=sender, friend_id=friend_id, db=db)
   

async def private_chat(user_id: int, friend_id: int, websocket: WebSocket, db: AsyncSession = Depends(get_session)):
    """Serve ``user_id``'s side of a private chat with ``friend_id``.

    Raises WebSocketException before the handshake when a user tries to chat
    with themselves. Each text message received on this socket is saved and
    relayed to the chat's other sockets (not echoed back to this one).
    """
    if user_id == friend_id:
        raise WebSocketException(
            code=status.WS_1007_INVALID_FRAME_PAYLOAD_DATA, reason='Собі не можна написати')
    chat_id = chat_id_generator(user_id, friend_id)

    # Accept first, so the manager never holds (and broadcasts to) a socket
    # whose handshake has not completed yet.
    await websocket.accept()

    try:
        await private_manager.register_websocket(chat_id=chat_id, user_id=user_id, friend_id=friend_id, websocket=websocket, db=db)
        while True:
            data = await websocket.receive_text()
            await private_manager.broadcast(chat_id=chat_id, message=data, sender_id=user_id, friend_id=friend_id, db=db, add_to_db=True, userwebsocket=websocket)
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    finally:
        # Unregister on every exit path (including errors from registration,
        # the database or send_text), not only on a clean disconnect.
        private_manager.disconnect(chat_id, websocket)

In the old version, `message = await self._receive()` returns a disconnect message ("disconnecting"); in the updated version, execution stays blocked inside `websocket.receive()`.

Starlette's `WebSocket` class (from the venv):

class WebSocket(HTTPConnection):
    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:```


Solution

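  • The test waits indefinitely for a message because the updated `broadcast` no longer sends a message back to the sender's own websocket, and the test has only that one connection open. These lines skip the sender: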

            for websocket in self.connections[chat_id]:
                if websocket != userwebsocket:
                    await websocket.send_text(message)