Tags: python, flask, python-asyncio, quart

WebSockets with the Python web framework Quart?


I need help with the Python web framework Quart, more specifically its websockets. I would like to register a client when it connects (add it to a Python list) and unregister it (remove it from the list) when it disconnects. The closest thing I could find on the web is this code:

import asyncio

connected = set()

async def handler(websocket, path):
    # Register.
    connected.add(websocket)
    try:
        # Implement logic here.
        # gather accepts coroutines; asyncio.wait rejects bare coroutines on Python 3.11+.
        await asyncio.gather(*(ws.send("Hello!") for ws in connected))
        await asyncio.sleep(10)
    finally:
        # Unregister.
        connected.remove(websocket)

source

But this does not work with Quart websockets.

Help would be appreciated.


Solution

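  • This decorator, when used to wrap a websocket handler, adds the websocket to the connected set when the handler starts and removes it when the handler exits. In Quart, websocket is a context-local proxy, so storing it directly would not pin a particular connection; its _get_current_object method returns the underlying websocket object for the current connection. The try/finally ensures the websocket is removed regardless of any errors raised, including the cancellation Quart raises in the handler when the client disconnects. Note that app.websocket must wrap (be placed above) collect_websocket.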

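    from functools import wraps

    from quart import Quart, websocket

    app = Quart(__name__)

    # Every currently open websocket connection.
    connected = set()

    def collect_websocket(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Store the real websocket object, not the context-local proxy.
            current = websocket._get_current_object()
            connected.add(current)
            try:
                return await func(*args, **kwargs)
            finally:
                # Runs on normal return, on errors and on cancellation
                # (client disconnect), so no stale entries are left behind.
                connected.remove(current)
        return wrapper


    @app.websocket('/ws')
    @collect_websocket
    async def ws():
        ...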
    

Edit: I am the Quart author.
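
A minimal sketch of how such a connected set is typically used to broadcast to every client, which was the goal of the original snippet. It models the pattern with the standard library only, so it runs without Quart: FakeConnection is a hypothetical stand-in for a websocket, registered is a hypothetical async context manager playing the role of the decorator's try/finally, and broadcast uses asyncio.gather because asyncio.wait no longer accepts bare coroutines as of Python 3.11.

```python
import asyncio
from contextlib import asynccontextmanager


class FakeConnection:
    """Hypothetical stand-in for a websocket: it just records what it is sent."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.sent: list[str] = []

    async def send(self, message: str) -> None:
        self.sent.append(message)


# Registry of currently open connections, like `connected` in the answer.
connected: set[FakeConnection] = set()


@asynccontextmanager
async def registered(conn: FakeConnection):
    """Register `conn` on entry and always unregister it on exit."""
    connected.add(conn)
    try:
        yield conn
    finally:
        # Runs on normal exit, on exceptions and on cancellation.
        connected.discard(conn)


async def broadcast(message: str) -> None:
    """Send `message` to every registered connection concurrently."""
    await asyncio.gather(*(conn.send(message) for conn in connected))


async def demo() -> tuple[list[str], list[str], int]:
    alice, bob = FakeConnection("alice"), FakeConnection("bob")
    async with registered(alice):
        async with registered(bob):
            await broadcast("hello")  # both clients are registered
        await broadcast("bye")  # bob has already been unregistered
    # A handler that crashes is still unregistered by the finally block.
    try:
        async with registered(FakeConnection("carol")):
            raise RuntimeError("handler failed")
    except RuntimeError:
        pass
    return alice.sent, bob.sent, len(connected)


alice_sent, bob_sent, remaining = asyncio.run(demo())
print(alice_sent, bob_sent, remaining)
```

In a real Quart application the same broadcast would iterate over the set filled by collect_websocket and await each websocket's send method.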