Tags: postgresql, websocket, python-asyncio, fastapi

FastAPI: Permanently running background task that listens to Postgres notifications and sends data to websocket


Minimal reproducible example:

import asyncio
import aiopg
from fastapi import FastAPI, WebSocket


dsn = "dbname=aiopg user=aiopg password=passwd host=127.0.0.1"
app = FastAPI()


class ConnectionManager:
    def __init__(self):
        # counts accepted websocket connections
        self.count_connections = 0

    # other class functions and variables are taken from FastAPI docs
    ...


manager = ConnectionManager()


async def send_and_receive_data(websocket: WebSocket):
    data = await websocket.receive_json()
    await websocket.send_text('Thanks for the message')
    # then process received data


# taken from official aiopg documentation
# the function listens to PostgreSQL notifications
async def listen(conn):
    async with conn.cursor() as cur:
        await cur.execute("LISTEN channel")
        while True:
            msg = await conn.notifies.get()


async def postgres_listen():
    async with aiopg.connect(dsn) as listenConn:
        listener = listen(listenConn)
        await listener


@app.get("/")
def read_root():
    return {"Hello": "World"}


@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    manager.count_connections += 1

    if manager.count_connections == 1:
        await asyncio.gather(
            send_and_receive_data(websocket),
            postgres_listen()
        )
    else:
        await send_and_receive_data(websocket)

Description of the problem:

I am building an app with Vue.js, FastAPI and PostgreSQL. In this example I attempt to use Postgres LISTEN/NOTIFY and push the notifications through a websocket. The app also has many ordinary HTTP endpoints alongside the websocket endpoint.

I want to run a permanent asynchronous background function from the start of the FastAPI app that sends messages to all websocket clients/connections. So when I run uvicorn main:app, it should start not only the FastAPI app but also my background function postgres_listen(), which notifies all websocket users when a new row is added to the table in the database.

I know that I can use asyncio.create_task() and place it in the on_* event, or even place it after the manager = ConnectionManager() line, but it does not work in my case: after any HTTP request (for instance, a call to read_root()), I get the same error described below.

As you can see, I start postgres_listen() in a strange way: inside websocket_endpoint(), and only when the first client connects to the websocket. Subsequent client connections do not trigger this function again. Everything works fine... until the first client/user disconnects (for example, closes the browser tab). When that happens, I immediately get a GeneratorExit error caused by psycopg2.OperationalError:

Future exception was never retrieved
future: <Future finished exception=OperationalError('Connection closed')>
psycopg2.OperationalError: Connection closed
Task was destroyed but it is pending!
task: <Task pending name='Task-18' coro=<Queue.get() done, defined at /home/user/anaconda3/lib/python3.8/asyncio/queues.py:154> wait_for=<Future cancelled>>

The error comes from the listen() function. After this error, I no longer get any notifications from the database because the asyncio Task is cancelled. There is nothing wrong with psycopg2, aiopg or asyncio. The problem is that I don't understand where to put the postgres_listen() function so that it is not cancelled after the first client disconnects. As far as I understand, I could write a Python script that connects to the websocket (so that it is the first client) and then runs forever, which would avoid the psycopg2.OperationalError exception, but that does not seem right.

My question is: where should I put the postgres_listen() function so that the first websocket connection can disconnect without consequences?

P.S. asyncio.shield() does not work either.


Solution

  • I have answered this on GitHub as well, so I am reposting it here.

    A working example can be found here: https://github.com/JarroVGIT/fastapi-github-issues/tree/master/5015

    # app.py
    import asyncio
    from asyncio import Queue, Task
    from typing import Any

    from fastapi import FastAPI, WebSocket, WebSocketDisconnect
    
    import uvicorn
    import websockets
    
    class Listener:
        def __init__(self):
            #Every incoming websocket connection adds its own Queue to this list called 
            #subscribers.
            self.subscribers: list[Queue] = []
            #This will hold an asyncio task which receives messages and broadcasts them 
            #to all subscribers.
            self.listener_task: Task
    
        async def subscribe(self, q: Queue):
            #Every incoming websocket connection must create a Queue and subscribe itself to 
            #this class instance 
            self.subscribers.append(q)
    
    
        async def start_listening(self):
            #Method that must be called on startup of application to start the listening 
            #process of external messages.
            self.listener_task = asyncio.create_task(self._listener())
    
        async def _listener(self) -> None:
            #The method with the infinite listener. In this example, it listens to a websocket
            #as it was the fastest way for me to mimic the 'infinite generator' in issue 5015
            #but this can be anything. It is started (via start_listening()) on startup of app.
            async with websockets.connect("ws://localhost:8001") as websocket:
                async for message in websocket:
                    for q in self.subscribers:
                        #important here: every websocket connection has its own Queue added to
                        #the list of subscribers. Here, we actually broadcast incoming messages
                        #to all open websocket connections.
                        await q.put(message)
    
        async def stop_listening(self):
            #closing off the asyncio task when stopping the app. This method is called on 
            #app shutdown
            if self.listener_task.done():
                self.listener_task.result()
            else:
                self.listener_task.cancel()
    
        async def receive_and_publish_message(self, msg: Any):
            #Called by the /add_item test endpoint (part of an earlier solution) to check
            #that a message is broadcast to all open websocket connections (it is).
            for q in self.subscribers:
                q.put_nowait(str(msg))
    
        #Note: missing here is any disconnect logic (e.g. removing the queue from the list of subscribers
        # when a websocket connection is ended or closed.)
    
            
    global_listener = Listener()
    
    app = FastAPI()
    
    @app.on_event("startup")
    async def startup_event():
        await global_listener.start_listening()
        return
    
    @app.on_event("shutdown")
    async def shutdown_event():
        await global_listener.stop_listening()
        return
    
    
    @app.get('/add_item/{item}')
    async def add_item(item: str):
        #this was a test endpoint, to see if new items were actually broadcast to all 
        #open websocket connections.
        await global_listener.receive_and_publish_message(item)
        return {"published_message": item}
    
    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        await websocket.accept()
        q: asyncio.Queue = asyncio.Queue()
        await global_listener.subscribe(q=q)
        try:
            while True:
                data = await q.get()
                await websocket.send_text(data)
        except WebSocketDisconnect:
            return
    
    
    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=8000)
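
The note at the end of the Listener class says that disconnect logic is missing. One possible way to add it, shown here only as a sketch, is to hand out each queue through an async context manager so that it is always removed again. The `subscription()` and `publish()` names are hypothetical and not part of the original app.py; the class below is a trimmed-down stand-in that keeps only the subscriber handling.

```python
import asyncio
from contextlib import asynccontextmanager


class Listener:
    """Stand-in for the Listener in app.py (subscriber handling only)."""

    def __init__(self) -> None:
        self.subscribers: list[asyncio.Queue] = []

    @asynccontextmanager
    async def subscription(self):
        # Hypothetical helper: register a fresh queue and guarantee it is
        # removed again, even if the consumer raises or is cancelled.
        q: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(q)
        try:
            yield q
        finally:
            self.subscribers.remove(q)

    async def publish(self, message: str) -> None:
        # Iterate over a copy so a subscriber leaving mid-broadcast
        # does not disturb the loop.
        for q in list(self.subscribers):
            await q.put(message)


async def demo():
    listener = Listener()
    async with listener.subscription() as q:
        await listener.publish("hello")
        message = await q.get()
    # After the block exits, the queue is no longer registered.
    return message, len(listener.subscribers)
```

In websocket_endpoint, the `while True` loop would then sit inside `async with global_listener.subscription() as q:`, so a closed connection no longer leaves a stale queue in the list.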
    

    As I didn't have access to a stream of messages I could subscribe to, I created a quick script that serves a websocket, so that the app.py above could listen to it (indefinitely) to mimic your use case.

    # generator.py
    from fastapi import FastAPI, WebSocket, WebSocketDisconnect
    import asyncio
    import uvicorn
    
    
    app = FastAPI()
    
    @app.websocket("/")
    async def ws(websocket: WebSocket):
        await websocket.accept()
        i = 0
        try:
            while True:
                await websocket.send_text(f"Hello - {i}")
                await asyncio.sleep(2)
                i += 1
        except WebSocketDisconnect:
            # Client went away: stop sending instead of looping forever.
            pass
    
    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=8001)
    

    The app.py listens to a websocket and publishes every incoming message to all clients connected to its own /ws endpoint.

    The generator.py is a simple FastAPI app with a websocket endpoint (the one our example app.py above listens to) that emits a message every 2 seconds to every connection it gets.
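
The lifecycle that app.py relies on (start one background task at startup, fan every message out to per-connection queues, cancel the task at shutdown) can be tried without FastAPI at all. The following is a minimal sketch using only asyncio; `fake_source`, `pump` and `stop` are hypothetical names, and `fake_source` merely imitates generator.py. Unlike `stop_listening()` in app.py, `stop()` here also awaits the cancelled task, which is one possible refinement rather than something the original code does.

```python
import asyncio
import contextlib


async def fake_source():
    # Hypothetical stand-in for the upstream websocket: yields forever.
    i = 0
    while True:
        await asyncio.sleep(0.01)
        yield f"Hello - {i}"
        i += 1


async def pump(source, subscribers: list) -> None:
    # Forward every upstream message to every subscriber queue.
    async for message in source:
        for q in list(subscribers):
            await q.put(message)


async def stop(task: asyncio.Task) -> None:
    # Cancel the background task and wait until it has actually finished,
    # so nothing is left pending when the event loop closes.
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


async def demo():
    a, b = asyncio.Queue(), asyncio.Queue()
    task = asyncio.create_task(pump(fake_source(), [a, b]))  # "startup"
    received = [(await a.get(), await b.get()) for _ in range(3)]
    await stop(task)                                          # "shutdown"
    return received, task.cancelled()
```

Because the task is created once and is not tied to any request handler, a consumer that stops reading its queue has no effect on the task itself.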

    To try this out:

    • Start generator.py (e.g. python3 generator.py on your command line when in your working folder)
    • Start app.py (either in debug mode in VS Code or the same way as above)
    • Connect to ws://localhost:8000/ws (= endpoint in app.py) with several websocket clients; you will see that they all receive the same message stream.

    NOTE: lots of this logic was inspired by Broadcaster (a Python module)
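
For the Postgres case in the question, the websocket loop inside `_listener()` could be swapped for the aiopg LISTEN loop the question already uses. The sketch below is untested against a real database and rests on assumptions: `pump_notifications` is a hypothetical helper, `conn` is assumed to be an open aiopg connection that exposes `cursor()` and the `notifies` queue exactly as in the question's listen() function, and each notification is assumed to carry its text in `.payload`.

```python
import asyncio


async def pump_notifications(conn, channel: str, subscribers: list) -> None:
    # `conn` is assumed to behave like the aiopg connection in the question.
    async with conn.cursor() as cur:
        # The channel name is interpolated as an identifier, so it must come
        # from trusted code, never from user input.
        await cur.execute(f"LISTEN {channel}")
        while True:
            msg = await conn.notifies.get()
            # Fan the notification payload out to every websocket queue.
            for q in list(subscribers):
                await q.put(msg.payload)
```

Inside `Listener._listener()` this might be called as `async with aiopg.connect(dsn) as conn: await pump_notifications(conn, "channel", self.subscribers)`, with `dsn` defined as in the question. Since the task is started in the startup event, it no longer belongs to the first websocket connection.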