Tags: python, websocket, async-await, fastapi, jax

WebSocket messages only sent at the end, not one by one, when using async/await and yield in nested for loops


I have a computationally heavy process that takes several minutes to complete on the server, so I want to send the result of every iteration to the client via WebSockets.

The overall application works, but all the messages arrive at the client in one big chunk after the entire simulation finishes. I must be missing something, because I expect each await websocket.send_json() to send its message while the process is running, not all of them at the end.
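
A commonly given explanation, not confirmed by the question itself, is that the simulation is synchronous, CPU-bound code running directly inside the async def endpoint. It therefore runs on the event loop's thread, and no other task on that loop gets to run until the endpoint reaches an await that actually suspends. Whether an individual send_json() call is flushed right away depends on the server (Uvicorn and its WebSocket backend), so the following stdlib-only sketch only illustrates the general asyncio behaviour: an asyncio.Queue stands in for the connection, and heavy_step, run and the consumer task are hypothetical names. It assumes Python 3.9+ for asyncio.to_thread.

```python
import asyncio
import time


def heavy_step(i: int) -> int:
    """Hypothetical stand-in for one synchronous simulation step."""
    time.sleep(0.05)  # blocks the thread it runs on, like a long computation
    return i * i


async def run(offload: bool, n: int = 3) -> list[str]:
    """Produce n results and record the order in which they are produced and delivered."""
    events: list[str] = []
    queue: asyncio.Queue = asyncio.Queue()  # unbounded, so put() never suspends

    async def consumer() -> None:
        # Plays the role of whatever actually delivers messages to the client.
        for _ in range(n):
            item = await queue.get()
            events.append(f"delivered {item}")

    task = asyncio.create_task(consumer())
    for i in range(n):
        if offload:
            # The step runs in a worker thread; awaiting it lets other tasks run.
            result = await asyncio.to_thread(heavy_step, i)
        else:
            # The step runs on the event loop thread; no other task can run.
            result = heavy_step(i)
        await queue.put(result)
        events.append(f"produced {result}")
    await task
    return events


print(asyncio.run(run(offload=False)))
# → ['produced 0', 'produced 1', 'produced 4', 'delivered 0', 'delivered 1', 'delivered 4']
print(asyncio.run(run(offload=True)))
# → ['produced 0', 'delivered 0', 'produced 1', 'delivered 1', 'produced 4', 'delivered 4']
```

With offload=False every result is produced before any is delivered, which resembles the "one big chunk" symptom; with offload=True each result is delivered before the next step starts. A cheaper workaround that is often suggested is an await asyncio.sleep(0) after each send, which yields to the loop between steps but still blocks it while each step is being computed.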

Server (Python, FastAPI)

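# A very simplified abstraction of the actual app.
# interval(), distributions() and NumpyEncoder are defined elsewhere in the app.
import json

from fastapi import FastAPI, WebSocket

app = FastAPI()

def simulate_intervals(data):
  for t in range(data.n_intervals):
    state = interval(data) # returns a JAX NumPy array
    yield state

def simulate(data):
  for key in range(data.n_trials):
    trial = simulate_intervals(data)
    yield trial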
@app.websocket("/ws")
async def socket(websocket: WebSocket):

  await websocket.accept()
  while True:
    # Get model inputs from client
    data = await websocket.receive_text()
    # Minimal computation
    nodes = distributions(data)

    nodosJson = json.dumps(nodes, cls=NumpyEncoder)
    # I expect this message to be sent early on,
    # but the client gets it at the end with all the other messages. 
    await websocket.send_json({"tipo": "nodos", "datos": json.loads(nodosJson)})
    
    # Heavy computation
    trials = simulate(data)

    for trialI, trial in enumerate(trials):
      for stateI, state in enumerate(trial):
        stateString = json.dumps(state, cls=NumpyEncoder)

        await websocket.send_json(
          {
            "tipo": "estado",
            "datos": json.loads(stateString),
            "trialI": trialI,
            "stateI": stateI,
          }
        )

    await websocket.send_json({"tipo": "estado", "msg": "fin"})
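
If the blocking-event-loop explanation applies, the nested loops above could be restructured so that each next() on the generators, which is where the heavy computation actually happens, runs in a worker thread. This is a minimal sketch, assuming Python 3.9+ and that each state is already JSON-serializable (in the real app it would first go through NumpyEncoder); push_states and _DONE are hypothetical names, and send stands for websocket.send_json.

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterator

_DONE = object()  # sentinel returned by next() once a generator is exhausted


async def push_states(
    send: Callable[[dict], Awaitable[Any]],
    trials: Iterator[Iterator[Any]],
) -> int:
    """Send one message per simulation step, computing each step in a worker thread.

    send is any coroutine function taking a dict (e.g. websocket.send_json).
    Returns the number of state messages sent.
    """
    sent = 0
    trial_i = 0
    while True:
        # Advance the outer generator off the event loop as well.
        trial = await asyncio.to_thread(next, trials, _DONE)
        if trial is _DONE:
            break
        state_i = 0
        while True:
            # next(gen, default) returns the sentinel instead of raising StopIteration.
            state = await asyncio.to_thread(next, trial, _DONE)
            if state is _DONE:
                break
            await send({"tipo": "estado", "datos": state, "trialI": trial_i, "stateI": state_i})
            sent += 1
            state_i += 1
        trial_i += 1
    await send({"tipo": "estado", "msg": "fin"})
    return sent
```

In the endpoint, the nested for loops and the final "fin" message would then become await push_states(websocket.send_json, simulate(data)), or, keeping the existing serialization, await push_states(lambda msg: websocket.send_json(json.loads(json.dumps(msg, cls=NumpyEncoder))), simulate(data)). The sentinel is there because asyncio does not pass a StopIteration raised in the worker thread through to the awaiting coroutine as-is. Note also that JAX dispatches work asynchronously, so part of the waiting may happen when the array is converted during serialization; in the real app it may be worth doing the NumpyEncoder conversion inside the worker thread too.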

For completeness, here is the basic client code.

Client

const ws = new WebSocket('ws://localhost:8000/ws');

ws.onopen = () => {
  console.log('Conexión exitosa'); // "Connection successful"
};

ws.onmessage = (e) => {
  const mensaje = JSON.parse(e.data);
  console.log(mensaje);
};

// botonEnviarDatos ("send data" button) is defined elsewhere in the page
botonEnviarDatos.onclick = () => {
  ws.send(JSON.stringify({...}));
};

Solution

  • I was not able to make it work as posted in my question; I am still interested in hearing from anyone who understands why the messages cannot be sent one by one instead of being held back until the end.

    For anyone interested, here is my current solution:

    Ping-pong messages between client and server

    I changed the logic so that the server and client constantly send each other messages, instead of the server trying to stream all the data in response to a single request from the client.

    This actually works much better than my original attempt, because I can detect when a socket gets disconnected and stop processing on the server. If the client disconnects, it stops sending requests for data, so the server never continues the heavy computation.

    Server

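    # A very simplified abstraction of the actual app.
    # interval(), distributions(), NumpyEncoder and args are defined elsewhere in the app.
    import json
    
    from fastapi import FastAPI, WebSocket
    
    app = FastAPI()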
    
    def simulate_intervals(data):
      for t in range(data.n_intervals):
        state = interval(data) # returns a JAX NumPy array
        yield state
    
    def simulate(data):
      for key in range(data.n_trials):
         trial = simulate_intervals(data)
         yield trial
    
    @app.websocket("/ws")
    async def socket(websocket: WebSocket):
    
      await websocket.accept()
      while True:
        # Get messages from the client as parsed JSON (receive_text() would return
        # a plain string, and data["tipo"] below would then raise a TypeError)
        data = await websocket.receive_json()
        
        # "tipo" is the type of message being sent between client and server.
        # Here, "tipo": "inicio" is the client sending its inputs and requesting the initial data.
        if data["tipo"] == "inicio":
          # Minimal computation
          nodes = distributions(data)
    
          nodosJson = json.dumps(nodes, cls=NumpyEncoder)
          # In this first interaction, the client gets the first message without delay. 
          await websocket.send_json({"tipo": "nodos", "datos": json.loads(nodosJson)})
    
          # simulate() is a generator function (it uses yield), so calling it does not
          # run any of the heavy computation yet; that only happens on next().
          trials = simulate(data)
          
          # define some initial variables to count the iterations
          trialI = 0
          stateI = 0
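          trialsLen = args.number_trials  # args comes from the actual app's configuration
          statesLen = 600  # must match the number of intervals simulate_intervals() yields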
          
          # Load the first trial (also a generator). Instead of the nested for loops
          # used before, the counters and next() step through the same iterations
          # one request at a time.
          trial = next(trials)
    
          # With the use of generators and next() it is possible to keep
          # this first message light on the server and send the first response
          # as quickly as possible.
        
        # This type of message asks for the next step of the simulation
        # without processing the entire model.
        elif data["tipo"] == "sim":
          # check if we are within the limits (before this was a nested for loop)
          if trialI < trialsLen and stateI < statesLen:
            # Trigger the next instance of the simulation
            state = next(trial)
            # update counter
            stateI = stateI + 1
            
            # Send the message with one step of the simulation.
            stateString = json.dumps(state, cls=NumpyEncoder)
            await websocket.send_json(
              {
                 "tipo": "estado",
                 "datos": json.loads(stateString),
                 "trialI": trialI,
                 "stateI": stateI,
              }
            )
            
            # Check if the second loop is done
            if stateI == statesLen:
              # update counter of first loop
              trialI = trialI + 1
              # update counter of second loop
              stateI = 0
              
              # Check if there are more pending trials,
              # otherwise stop and notify the client we are done.
              try:
                trial = next(trials)
              except StopIteration:
                await websocket.send_json({"tipo": "fin"})
    
    

    Client

    Just the part that actually changed:

    ws.onmessage = (e) => {
      const mensaje = JSON.parse(e.data);
      
      // Simply check the type of incoming message so it can be processed
      if (mensaje.tipo === 'fin') {
        viz.calcularResultados();
      } else if (mensaje.tipo === 'nodos') {
        viz.pintarNodos(mensaje.datos);
      } else if (mensaje.tipo === 'estado') {
        viz.sumarEstado(mensaje.datos);
      }
    
      // After receiving a message, ping the server for the next one 
      ws.send(
        JSON.stringify({
          tipo: 'sim',
        })
      );
    };
    

    This seems like a reasonable solution for keeping the server and client working together. I can show the progress of a long simulation in the client, and the user experience is much better than waiting a long time for the server to respond. I hope it helps others with a similar problem.
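
The counters in the solution's server (trialI, stateI, trialsLen, statesLen) can also be derived from the generators themselves, so the hard-coded 600 is not needed. A minimal sketch, assuming the same nested-generator shape as simulate(); SimulationCursor, next_message and _DONE are hypothetical names, and datos would still need the NumpyEncoder round-trip in the real app. It also makes visible the point from the server comments that creating a generator runs nothing until next() is called.

```python
from typing import Any, Iterator, Optional

_DONE = object()  # sentinel returned by next() once a generator is exhausted


class SimulationCursor:
    """Hypothetical helper: hands out one simulation step per "sim" request.

    A trial ends when its generator is exhausted, and the run ends when the
    outer generator (e.g. simulate(data)) has no more trials.
    """

    def __init__(self, trials: Iterator[Iterator[Any]]) -> None:
        self._trials = trials                      # nothing is computed yet
        self._trial: Optional[Iterator[Any]] = None
        self.trial_i = -1                          # index of the current trial
        self.state_i = 0                           # index of the next state in it
        self.finished = False

    def next_message(self) -> dict:
        """Compute the next state and return the message to send to the client."""
        while not self.finished:
            if self._trial is None:
                # Start the next trial, or finish if there are none left.
                trial = next(self._trials, _DONE)
                if trial is _DONE:
                    self.finished = True
                    break
                self._trial = trial
                self.trial_i += 1
                self.state_i = 0
            # This next() call is where one step of the heavy computation runs.
            state = next(self._trial, _DONE)
            if state is _DONE:
                self._trial = None  # current trial exhausted; move on
                continue
            message = {
                "tipo": "estado",
                "datos": state,
                "trialI": self.trial_i,
                "stateI": self.state_i,
            }
            self.state_i += 1
            return message
        return {"tipo": "fin"}
```

In the endpoint, the "inicio" branch would create cursor = SimulationCursor(simulate(data)) and the "sim" branch would send cursor.next_message(); calling it as await asyncio.to_thread(cursor.next_message) would additionally keep the event loop free while a step is computed. Unlike the solution above, stateI here is 0-based, matching enumerate() in the question's code.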