Tags: python, websocket, async-await, fastapi, python-trio

Proper way to cancel remaining trio nursery tasks inside a FastAPI websocket?


I'm still quite new to websockets and I've been given a problem I'm having a hard time solving.

I need to build a websocket endpoint with FastAPI in which a group of tasks is run asynchronously (I went with trio for this), with each task returning a JSON value through the websocket in real time.

I've managed to meet these requirements, with my code looking like this:

@router.websocket('/stream')
async def runTasks(
        websocket: WebSocket
):
    # Initialise websocket
    await websocket.accept()
    while True:
        # Receive data
        tasks = await websocket.receive_json()
        # Run tasks asynchronously (limiting to 10 tasks at a time)
        async with trio.open_nursery() as nursery:
            limit = trio.CapacityLimiter(10)
            for task in tasks:
                nursery.start_soon(run_task, limit, task, websocket)

With run_task looking something like this:

async def run_task(limit, task, websocket):
    async with limit:
        # Complete task / transaction
        await websocket.send_json({"placeholder": "data"})

But now, given two scenarios, I'm supposed to cancel or skip the remaining nursery tasks, and I'm at a bit of a loss as to how to achieve that.

The two scenarios I'm given are as follows:

  • Scenario 1: Suppose the endpoint is called when a user presses a button. If the user presses the button again while some tasks are still running, those tasks should be cancelled or skipped and the process should begin anew.

  • Scenario 2: If the websocket is closed, or the user refreshes the page or leaves before the nursery tasks complete, the remaining tasks should be cancelled or skipped.

I've been reading "Python - How to cancel a specific task spawned by a nursery in python-trio", but I'm still puzzled as to how I can cancel the previous nursery through its cancel scope before entering the new one. Should I create an additional task that watches a variable and cancels the nursery once it changes? But then I'd have to stop that task once all the other tasks have finished.


Solution

  • For scenario 1:

    1. Create a dictionary in the global namespace that stores a cancel scope and an event per client (key: UUID, value: Tuple[trio.CancelScope, trio.Event]).
    2. Assign every client a unique UUID (or anything else that is unique to the client).
    3. Have the client send its UUID at the start of each connection.
    4. Check whether the dictionary already has that UUID as a key. If it does, cancel the stored scope and wait for the event to be set.
    5. Then do the actual transmission.

  • For scenario 2:

    A websocket has no way of knowing that the client is gone unless the client closes the connection explicitly. The best option I can think of is therefore to enforce a timeout and wait for the client's response after every transmission (which makes this method a bit less efficient).

    It would probably be better to make a periodic check with some failure tolerance, such as checking every 5 minutes and tolerating up to 2 successive timeouts, but for simplicity the code below just enforces a timeout on every transmission.


    Below is demonstration code for the ideas above.

    Client code:

    Since I don't know what your client code looks like, I wrote a small client just to test your scenarios.

    It is a bit buggy, as I never learned JS, so please don't judge the client code too seriously!

    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Websocket test</title>
    </head>
    <body>
        <button id="start">Start connection</button>
        <button id="close" disabled>Close connection</button>
        <input type="text" id="input_" value="INPUT_YOUR_UUID">
    
        <div id="state">Status: Waiting for connection</div>
    
        <script>
            let state = document.getElementById("state")
            let start_btn = document.getElementById("start")
            let close_btn = document.getElementById("close")
            let input_ = document.getElementById("input_")
    
            function sleep(sec) {
                state.textContent = `Status: sleeping ${sec} seconds`
                return new Promise((func) => setTimeout(func, sec * 1000))
            }
    
            function websocket_test() {
                return new Promise((resolve, reject) => {
                    let socket = new WebSocket("ws://127.0.0.1:8000/stream")
    
                    socket.onopen = function () {
                        state.textContent = "Status: Sending UUID - " + input_.value
                        socket.send(input_.value)
                        close_btn.disabled = false
                        close_btn.onclick = function () {socket.close()}
                    }
                    socket.onmessage = function (msg) {
                        state.textContent = "Status: Message Received - " + msg.data
                        socket.send("Received")
                    }
                    socket.onerror = function (error) {
                        reject(error)
                        state.textContent = "Status: Error encountered"
                    }
                    socket.onclose = function () {
                        state.textContent = "Status: Connection Stopped"
                        close_btn.disabled = true
                    }
                })
            }
    
            start_btn.onclick = websocket_test
    
        </script>
    </body>
    </html>
    

    Server code:

    In earlier testing I saw the server raise timeouts, but I can't reproduce them. You might not need the trio.fail_after and except trio.TooSlowError parts if you are confident about the client's behavior.

    """
    Nursery cancellation demo
    """
    import itertools
    from typing import Dict, Tuple
    
    import trio
    import fastapi
    import hypercorn
    from hypercorn.trio import serve
    
    
    GLOBAL_NURSERY_STORAGE: Dict[str, Tuple[trio.CancelScope, trio.Event]] = {}
    TIMEOUT = 5
    
    router = fastapi.APIRouter()
    
    
    @router.websocket('/stream')
    async def run_task(websocket: fastapi.WebSocket):
        # accept and receive UUID
        # Replace UUID with anything client-specific
        await websocket.accept()
        uuid_ = await websocket.receive_text()
    
        print(f"[{uuid_}] CONNECTED")
    
        # check whether a nursery already exists for this UUID; if so, cancel it and wait for it to finish.
        if uuid_ in GLOBAL_NURSERY_STORAGE:
            print(f"[{uuid_}] STOPPING NURSERY")
            cancel_scope, event = GLOBAL_NURSERY_STORAGE[uuid_]
            cancel_scope.cancel()
            await event.wait()
    
        # create new event, and start new nursery.
        cancel_done_event = trio.Event()
    
        async with trio.open_nursery() as nursery:
            # save ref
            GLOBAL_NURSERY_STORAGE[uuid_] = nursery.cancel_scope, cancel_done_event
    
            try:
                for n in itertools.count(0, 1):
                    nursery.start_soon(task, n, uuid_, websocket)
                    await trio.sleep(1)
    
                    # wait for client response
                    with trio.fail_after(TIMEOUT):
                        recv = await websocket.receive_text()
                        print(f"[{uuid_}] RECEIVED {recv}")
    
            except trio.TooSlowError:
                # client possibly left without proper disconnection, due to network issue
                print(f"[{uuid_}] CLIENT TIMEOUT")
    
            except fastapi.websockets.WebSocketDisconnect:
                # client performed proper disconnection
                print(f"[{uuid_}] CLIENT DISCONNECTED")
    
        # fire event, and pop reference if any.
        cancel_done_event.set()
        GLOBAL_NURSERY_STORAGE.pop(uuid_, None)
        print(f"[{uuid_}] NURSERY STOPPED & REFERENCE DROPPED")
    
    
    async def task(text, uuid_, websocket: fastapi.WebSocket):
        await websocket.send_text(str(text))
        print(f"[{uuid_}] SENT {text}")
    
    
    if __name__ == '__main__':
        cornfig = hypercorn.Config()
        # cornfig.bind = ["127.0.0.1:8000"]  # host:port only, without a ws:// scheme
        trio.run(serve, router, cornfig)
    

    Example run output:

    Client

    [Screenshot of the client page; image not included]

    Server

    [2022-01-31 21:23:12 +0900] [17204] [INFO] Running on http://127.0.0.1:8000 (CTRL + C to quit)
    [2] CONNECTED      < start connection on tab 2
    [2] SENT 0
    [2] RECEIVED Received
    [2] SENT 1
    [2] RECEIVED Received
    [2] SENT 2
    [2] RECEIVED Received
    [2] SENT 3
    [2] RECEIVED Received
    [2] SENT 4
    [1] CONNECTED      < start connection on tab 1
    [1] SENT 0
    [2] RECEIVED Received
    [2] SENT 5
    [1] RECEIVED Received
    [1] SENT 1
    ...
    [2] SENT 18
    [1] RECEIVED Received
    [1] SENT 14
    [2] RECEIVED Received
    [2] SENT 19
    [1] CLIENT DISCONNECTED      < closed connection on tab 1
    [1] NURSERY STOPPED & REFERENCE DROPPED      < tab 1 nursery terminated
    [2] RECEIVED Received
    [2] SENT 20
    [2] RECEIVED Received
    [2] SENT 21
    [1] CONNECTED      < start connection on tab 1
    [1] SENT 0
    [2] RECEIVED Received
    [2] SENT 22
    [1] RECEIVED Received
    ...
    [2] SENT 26
    [1] RECEIVED Received
    [1] SENT 5
    [2] CLIENT DISCONNECTED      < tab 2 closed
    [2] NURSERY STOPPED & REFERENCE DROPPED      < tab 2 nursery terminated
    [1] RECEIVED Received
    [1] SENT 6
    [1] RECEIVED Received
    [1] SENT 7
    [1] RECEIVED Received
    [1] SENT 8
    [1] CONNECTED      < start another connection on tab 1 without closing
    [1] STOPPING NURSERY      < previous connection on tab 1 terminating
    [1] NURSERY STOPPED & REFERENCE DROPPED      < previous connection on tab 1 terminated
    [1] SENT 0
    [1] RECEIVED Received
    [1] SENT 1
    ...
    [1] RECEIVED Received
    [1] SENT 8
    [1] CLIENT DISCONNECTED      < Refreshed tab 1
    [1] NURSERY STOPPED & REFERENCE DROPPED      < tab 1 nursery terminated
    ...