Tags: python, websocket, quart, python-trio

Migrating a Quart project with websockets from asyncio to trio


I'm trying to convert my asyncio project to Trio. I understand that I have to use memory channels instead of Queues, but for some reason I don't get the result I expect.

My main problem is that when I run two clients, the first one is not notified when the second one leaves (broadcasting the 'part' message from the server raises an error). Another problem is that the client sometimes exits immediately after opening the websocket. With asyncio, everything works fine.

Here is the stack trace I get when the second client disconnects:

[2021-07-30 18:39:51,899] ERROR in app: Exception on websocket /ws
Traceback (most recent call last):
  File "/tmp/debug/venv/lib/python3.9/site-packages/quart_trio/app.py", line 175, in handle_websocket
    return await self.full_dispatch_websocket(websocket_context)
  File "/tmp/debug/venv/lib/python3.9/site-packages/quart_trio/app.py", line 197, in full_dispatch_websocket
    result = await self.handle_user_exception(error)
  File "/tmp/debug/venv/lib/python3.9/site-packages/quart_trio/app.py", line 166, in handle_user_exception
    raise error
  File "/tmp/debug/venv/lib/python3.9/site-packages/quart_trio/app.py", line 195, in full_dispatch_websocket
    result = await self.dispatch_websocket(websocket_context)
  File "/tmp/debug/venv/lib/python3.9/site-packages/quart/app.py", line 1651, in dispatch_websocket
    return await self.ensure_async(handler)(**websocket_.view_args)
  File "/tmp/debug/server.py", line 103, in wsocket
    nursery.start_soon(receiving, u)
  File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 815, in __aexit__
    raise combined_error_from_nursery
trio.MultiError: Cancelled(), Cancelled(), Cancelled()

Details of embedded exception 1:

  Traceback (most recent call last):
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 1172, in raise_cancel
      raise Cancelled._create()
  trio.Cancelled: Cancelled

Details of embedded exception 2:

  Traceback (most recent call last):
    File "/tmp/debug/server.py", line 68, in receiving
      data = await websocket.receive_json()
    File "/tmp/debug/venv/lib/python3.9/site-packages/quart/wrappers/websocket.py", line 68, in receive_json
      data = await self.receive()
    File "/tmp/debug/venv/lib/python3.9/site-packages/quart/wrappers/websocket.py", line 57, in receive
      return await self._receive()
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_channel.py", line 314, in receive
      return await trio.lowlevel.wait_task_rescheduled(abort_fn)
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_traps.py", line 166, in wait_task_rescheduled
      return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
    File "/tmp/debug/venv/lib/python3.9/site-packages/outcome/_impl.py", line 138, in unwrap
      raise captured_error
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 1172, in raise_cancel
      raise Cancelled._create()
  trio.Cancelled: Cancelled

Details of embedded exception 3:

  Traceback (most recent call last):
    File "/tmp/debug/server.py", line 54, in sending
      data = await u.queue_recv.receive()
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_channel.py", line 314, in receive
      return await trio.lowlevel.wait_task_rescheduled(abort_fn)
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_traps.py", line 166, in wait_task_rescheduled
      return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
    File "/tmp/debug/venv/lib/python3.9/site-packages/outcome/_impl.py", line 138, in unwrap
      raise captured_error
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 1172, in raise_cancel
      raise Cancelled._create()
  trio.Cancelled: Cancelled

  During handling of the above exception, another exception occurred:

  Traceback (most recent call last):
    File "/tmp/debug/server.py", line 63, in sending
      await broadcast({'type': 'part', 'data': u.name})
    File "/tmp/debug/server.py", line 75, in broadcast
      await user.queue_send.send(message)
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_channel.py", line 159, in send
      await trio.lowlevel.checkpoint_if_cancelled()
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 2361, in checkpoint_if_cancelled
      await _core.checkpoint()
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 2339, in checkpoint
      await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_traps.py", line 166, in wait_task_rescheduled
      return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
    File "/tmp/debug/venv/lib/python3.9/site-packages/outcome/_impl.py", line 138, in unwrap
      raise captured_error
    File "/tmp/debug/venv/lib/python3.9/site-packages/trio/_core/_run.py", line 1172, in raise_cancel
      raise Cancelled._create()
  trio.Cancelled: Cancelled

Here is the code (set TRIO to False to use asyncio):

server.py

#!/usr/bin/env python
from quart import Quart, websocket, request, jsonify, json
from quart_trio import QuartTrio
from functools import wraps
import uuid
import trio
import asyncio
from quart_auth import AuthUser, AuthManager, login_user, _AuthSerializer

TRIO = True

if TRIO:
    app = QuartTrio(__name__)
else:
    app = Quart(__name__)
app.secret_key = '**changeme**'

authorized_users = set()

class User(AuthUser):
    @staticmethod
    def current():
        token = websocket.cookies['QUART_AUTH']
        serializer = _AuthSerializer('**changeme**', 'quart auth salt')
        user_id = serializer.loads(token)
        for u in authorized_users:
            if u.auth_id == user_id:
                return u
        return None

    def __init__(self, auth_id):
        super().__init__(auth_id)
        self.name = None
        self.queue = None # asyncio
        self.queue_send = None #trio
        self.queue_recv = None #trio
        self.connected = False
        self.websockets = set()    

    def to_dict(self):
        return {
            'id': self.auth_id,
            'name': self.name
        }

auth_manager = AuthManager()
auth_manager.user_class = User

async def sending(u: User):
    await broadcast({'type': 'join', 'data': u.name})
    try:
        while True:
            if TRIO:
                data = await u.queue_recv.receive()
            else:
                data = await u.queue.get()
            for s in u.websockets:
                await s.send_json(data)
    finally:
        u.websockets.remove(websocket._get_current_object())
        if len(u.websockets) == 0:
            u.connected = False
            await broadcast({'type': 'part', 'data': u.name})


async def receiving(u: User):
    while True:
        data = await websocket.receive_json()
        if data['type'] == 'msg':
            await broadcast({'type': 'msg', 'user': u.name, 'data': data['data']})

async def broadcast(message):
    for user in [u for u in authorized_users if u.connected]:
        if TRIO:
            await user.queue_send.send(message)
        else:
            await user.queue.put(message)

@app.route('/api/v1/auth', methods=['POST'])
async def auth_login():
    data = await request.json
    user_id = str(uuid.uuid4())[:8]
    u = User(user_id)
    u.name = data['login'] or 'Anonymous'+user_id
    if TRIO:
        u.queue_send, u.queue_recv = trio.open_memory_channel(float('inf'))
    else:
        u.queue = asyncio.Queue()
    login_user(u, True)
    authorized_users.add(u)
    return jsonify({'id': user_id, 'name': u.name}), 200

@app.websocket('/ws')
async def wsocket():
    u = User.current()
    if u is None:
        return
    u.websockets.add(websocket._get_current_object())
    u.connected = True
    if TRIO:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(sending, u)
            nursery.start_soon(receiving, u)
    else:
        producer = asyncio.create_task(sending(u))
        consumer = asyncio.create_task(receiving(u))
        await asyncio.gather(producer, consumer)


auth_manager.init_app(app)

if __name__ == "__main__":
    app.run(host='localhost', port=8080)

client.py

#!/usr/bin/env python

import asks
import trio
import trio_websocket
import json

asks.init(trio)

class User:
    def __init__(self, name: str="") -> None:
        self.name = name

class Client(User):
    def __init__(self) -> None:
        super(Client, self).__init__()
        self.web_url = 'http://localhost:8080/api/v1'
        self.ws_url = 'ws://localhost:8080/ws'
        self.ws = None
        self.nursery = None
        self.cookiejar = {}
    
    async def send(self, msg: dict) -> None:
        if self.ws is not None:
            await self.ws.send_message(json.dumps(msg))

    async def reader(self, websocket) -> None:
        while True:
            try:
                message_raw = await websocket.get_message()
                msg = json.loads(message_raw)
                if msg['type'] == 'msg':
                    print(f"<{msg['user']}> {msg['data']}")
                elif msg['type'] == 'join':
                    print(f"* {msg['data']} joined")
                elif msg['type'] == 'part':
                    print(f"* {msg['data']} left")
            except trio_websocket.ConnectionClosed:
                break

    async def login(self) -> None:
        rlogin = await asks.post(self.web_url + '/auth', json={'login': self.name, 'password': 'password'})
        for c in rlogin.cookies:
            if c.name == 'QUART_AUTH':
                self.cookiejar = {'QUART_AUTH': c.value}

    async def connect(self) -> None:
        await self.login()
        async with trio_websocket.open_websocket_url(self.ws_url, extra_headers=[('Cookie', 'QUART_AUTH'+'='+self.cookiejar['QUART_AUTH'])]) as websocket:
            self.ws = websocket
            await self.send({'type': 'msg', 'data': 'hello'})
            async with trio.open_nursery() as nursery:
                self.nursery = nursery
                nursery.start_soon(self.reader, websocket)

    def run(self) -> None:
        trio.run(self.connect)

c = Client()
c.name = 'clientA'
c.run()

Edit: I tested with anyio. With anyio on the trio backend I get the same behaviour, and anyio on the asyncio backend reproduces the problem too (just without any exception). So I guess the problem comes from replacing the Queue.
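For comparison, here is a minimal stdlib-only sketch (no Quart, no Trio) of what plain asyncio does when a task is cancelled. The names `worker`, `main` and the `log` list are hypothetical illustrations. asyncio delivers `CancelledError` once, at the await where the task was waiting, so awaits inside a `finally` block still run normally; this is consistent with the asyncio version of the server managing to broadcast 'part'.

```python
import asyncio


async def worker(queue: asyncio.Queue, log: list) -> None:
    """Wait on a queue, then try to announce 'part' while being cancelled."""
    try:
        await queue.get()  # blocks until the task is cancelled
    finally:
        # asyncio raised CancelledError once, at the await above. This await
        # really yields to the event loop, yet it is not cancelled again.
        await asyncio.sleep(0)
        await queue.put("part")
        log.append("part sent")


async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    log: list = []
    task = asyncio.create_task(worker(queue, log))
    await asyncio.sleep(0)  # let the worker start waiting on queue.get()
    task.cancel()  # roughly what happens to the handler when a client leaves
    try:
        await task
    except asyncio.CancelledError:
        log.append("worker cancelled")
    log.append(f"queued: {queue.get_nowait()}")
    return log


if __name__ == "__main__":
    print(asyncio.run(main()))
```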


Solution

  • OK @tibs, I think I've found the issue. The problem is the way Trio handles cancellation. For the full details, see the docs:

    https://trio.readthedocs.io/en/stable/reference-core.html#cancellation-and-timeouts

    However, to explain what's going on here: when a user disconnects, Quart-Trio raises a Cancelled exception in every coroutine that is running or waiting under that websocket. For each websocket user, there are two places where the code will currently be waiting:

    In async def sending(u: User):

    async def sending(u: User):
        await broadcast({'type': 'join', 'data': u.name})
        try:
            while True:
                if TRIO:
                    data = await u.queue_recv.receive()  # <-- code is waiting here; Cancelled is raised here
                else:
                    data = await u.queue.get()
                for s in u.websockets:
                    await s.send_json(data)
        finally:
            u.websockets.remove(websocket._get_current_object())
            if len(u.websockets) == 0:
                u.connected = False
                await broadcast({'type': 'part', 'data': u.name})
    

    In async def receiving(u: User):

    async def receiving(u: User):
        while True:
            data = await websocket.receive_json()  # <-- code is waiting here; Cancelled is raised here
            if data['type'] == 'msg':
                await broadcast({'type': 'msg', 'user': u.name, 'data': data['data']})
    

    Okay, so what happens from here? Well, in the sending() function we move down to the finally block, which begins executing, but then we call another awaitable function:

        finally:
            u.websockets.remove(websocket._get_current_object())
            if len(u.websockets) == 0:
                u.connected = False
                await broadcast({'type': 'part', 'data': u.name})  # <-- we call an awaitable here
    

    From the Trio docs:

    Cancellations in Trio are “level triggered”, meaning that once a block has been cancelled, all cancellable operations in that block will keep raising Cancelled.

    So when await broadcast(...) is called, it is cancelled immediately. asyncio behaves differently: it delivers a CancelledError once, at the point where the task was waiting, and later awaits inside the finally block run normally. This explains why your "part" message is never sent. So with Trio, if you want to do some cleanup work while you are being cancelled, you should open a new cancel scope and shield it from cancellation, like this:

    async def sending(u: User):
        await broadcast({'type': 'join', 'data': u.name})
        try:
            while True:
                if TRIO:
                    data = await u.queue_recv.receive()  # <-- code is waiting here; Cancelled is raised here
                else:
                    data = await u.queue.get()
                for s in u.websockets:
                    await s.send_json(data)
        finally:
            u.websockets.remove(websocket._get_current_object())
            if len(u.websockets) == 0:
                u.connected = False
                with trio.move_on_after(5) as leaving_cancel_scope:
                    # Shield from the cancellation for 5s to run the broadcast of leaving
                    leaving_cancel_scope.shield = True
                    await broadcast({'type': 'part', 'data': u.name})
    

    Alternatively, you could start the broadcast coroutine in the app's nursery. Be aware that if broadcast(...) crashes, it will crash the whole running app, unless you put a try/except in the broadcast(...) function:

    async def sending(u: User):
        await broadcast({'type': 'join', 'data': u.name})
        try:
            while True:
                if TRIO:
                    data = await u.queue_recv.receive()
                else:
                    data = await u.queue.get()
                for s in u.websockets:
                    await s.send_json(data)
        finally:
            u.websockets.remove(websocket._get_current_object())
            if len(u.websockets) == 0:
                u.connected = False
                app.nursery.start_soon(broadcast, {'type': 'part', 'data': u.name})
    

    Even after this, the Cancelled exceptions still propagate to your websocket function, so you may want to catch them there. Be aware that trio.Cancelled (and the MultiError wrapping it) derives from BaseException rather than Exception, so you need to catch BaseException, something like:

    @app.websocket('/ws')
    async def wsocket():
        u = User.current()
        if u is None:
            return
        u.websockets.add(websocket._get_current_object())
        u.connected = True
        if TRIO:
            try:
                async with trio.open_nursery() as nursery:
                    nursery.start_soon(sending, u)
                    nursery.start_soon(receiving, u)
            except BaseException as e:
                print(f'websocket funcs crashed with exception: {e}')
    

    This is because Trio doesn't let you silently drop exceptions: you have to either catch them or crash. I hope this is enough to get you started on fixing the issues you are seeing.