Search code examples
Tags: logging · websocket · fastapi · solver · pyomo

Pyomo - Streaming model solver output via Websocket with FastAPI


I'm creating an optimization model with pyomo and trying to expose it via FastAPI (plain FastAPI or celery). In this process I'm trying to capture the solver output to stream it near-real-time to user to inspect it. To stream this data, I'm trying to use WebSockets

An example for this app can be as follows: Structure

|--main.py
|--model.py
|--template.py

model.py

import pyomo.environ as pyo

# Minimal LP used as a long-running example: minimize the sum of 10,000
# non-negative variables subject to that sum being at least 10 (optimum: 10).
model = pyo.ConcreteModel()

model.t = pyo.RangeSet(10000)
model.x = pyo.Var(model.t, domain=pyo.NonNegativeReals)

def constraint(model):
    # Total allocation over all indices must reach the demand of 10.
    return sum(model.x[t] for t in model.t) >= 10


model.c1 = pyo.Constraint(rule=constraint)

model.obj = pyo.Objective(
    expr=sum(model.x[t] for t in model.t),
    sense=pyo.minimize,
)

# main.py does `from model import instance`, but this module only defined
# `model`, so that import fails. Expose the model under the expected name.
instance = model

main.py

import logging

from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from pyomo.common.tee import capture_output
from pyomo.environ import SolverFactory

from model import instance
from template import template

# BUG FIX: logging.getLogger() has no `mod_name` keyword -- passing it raises
# TypeError. The logger name is the first positional argument (`name`).
logger = logging.getLogger("simple-example-logger")

app = FastAPI()

@app.get("/")
async def get():
    """Serve the HTML client page that opens the WebSocket connection."""
    return HTMLResponse(template)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """On every client message, solve the model and send the captured
    solver log back over the WebSocket.

    Note: the log is only sent after the solve finishes -- the whole
    output arrives at once, which is the limitation the Solution below
    addresses.
    """
    await websocket.accept()
    # --- Solve Model ---
    while True:
        data = await websocket.receive_text()
        await websocket.send_text("Optimizing...")
        solver = SolverFactory("cbc")
        print("Loading...")
        with capture_output() as captured:
            solver_logger = logging.StreamHandler(captured)
            logger.addHandler(solver_logger)
            solver.solve(instance, tee=True)
            text = solver_logger.stream.getvalue()
            print(text)
            # The captured log contains real newline characters, so replace
            # "\n" -- the original r"\n" only matched a literal backslash
            # followed by 'n' and never fired. (A stray trailing backtick
            # after this statement was also removed: it was a SyntaxError.)
            html_text = text.replace("\n", "</br>")
            await websocket.send_text(f"Message text was: {html_text}")

template.py

This is just a .py file containing a string with the HTML page that is served to the WebSocket client.

# Minimal chat-style page: a text input sends messages over the WebSocket at
# ws://localhost:8000/ws, and each received message (a batch of solver log
# lines) is appended to the <ul> as a new list item.
template = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            var ws = new WebSocket("ws://localhost:8000/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage(event) {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
                event.preventDefault()
            }
        </script>
    </body>
</html>
"""

If I run `uvicorn main:app --reload` I'm able to stream the solver log into the template, but this only happens after the solving process has finished. For a much larger model that takes long to solve, I want to watch the logs in as near to real time as possible.

Any hint on this approach?


Solution

  • I finally could handle this. I'm not that sure that this is the best approach, but this works

    Thanks to @Chris for the comment to This post.

    I also used This answer to achieve this streaming

    I decided to "stream" a "batch" of logs to avoid overhead, since some solvers (e.g., SCIP) generate a lot of verbose output while optimizing; therefore I stream one batch every 5 seconds.

    I wrote a class that streams by replacing sys.stdout. The method _write_by_batches creates a new event loop so the log batch can be sent asynchronously through the WebSocket from the solver's worker thread; a fresh event loop is created each time a batch is streamed, because I want to avoid the overhead of streaming every single log line individually.

    class MyStdOut(TextIOBase):
        """File-like stdout replacement that accumulates solver output and
        pushes a batch over the WebSocket at most once every 5 seconds."""

        def __init__(self, ws: WebSocket, orig_stdout=None):
            self._ws = ws
            self.orig_stdout = orig_stdout          # optionally mirror output to the real stdout
            self._last_update = datetime.now()      # time of the last WebSocket flush
            self._batch = ""                        # log text accumulated since the last flush

        def write(self, s):
            self._stream_solver_log(
                current_time=datetime.now(),
                msg=s,
            )
            if self.orig_stdout:
                self.orig_stdout.write(s)

        def _write_by_batches(self, s):
            # The solver runs in a worker thread with no running event loop,
            # so create a private loop just to await the send -- and close it
            # afterwards; the original leaked one loop per flush.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(self._ws.send_text(s))
            finally:
                loop.close()

        def _stream_solver_log(self, current_time: datetime, msg: str):
            # Flush the pending batch once 5 seconds have elapsed, but ALWAYS
            # keep the incoming message: the original `else` silently dropped
            # every message that arrived on a flush tick.
            if (current_time - self._last_update).total_seconds() >= 5:
                self._write_by_batches(self._batch)
                self._batch = ""
                self._last_update = current_time
            self._batch = f"{self._batch}\n{msg}"
    
    

    Then in the websocket endpoint, I get the running loop in order to run the solver.solve(instance, tee=True) using loop.run_in_executor.

    The whole main.py script is as follows:

    main.py

    from io import TextIOBase
    import asyncio
    from datetime import datetime
    import sys
    
    from fastapi import FastAPI, WebSocket
    from fastapi.responses import HTMLResponse
    from pyomo.environ import SolverFactory
    
    from model import instance
    from template import template
    
    
    class MyStdOut(TextIOBase):
        """File-like stdout replacement that accumulates solver output and
        pushes a batch over the WebSocket at most once every 5 seconds."""

        def __init__(self, ws: WebSocket, orig_stdout=None):
            self._ws = ws
            self.orig_stdout = orig_stdout          # optionally mirror output to the real stdout
            self._last_update = datetime.now()      # time of the last WebSocket flush
            self._batch = ""                        # log text accumulated since the last flush

        def write(self, s):
            self._stream_solver_log(
                current_time=datetime.now(),
                msg=s,
            )
            if self.orig_stdout:
                self.orig_stdout.write(s)

        def _write_by_batches(self, s):
            # The solver runs in a worker thread with no running event loop,
            # so create a private loop just to await the send -- and close it
            # afterwards; the original leaked one loop per flush.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(self._ws.send_text(s))
            finally:
                loop.close()

        def _stream_solver_log(self, current_time: datetime, msg: str):
            # Flush the pending batch once 5 seconds have elapsed; the
            # incoming message is always appended so nothing is dropped.
            if (current_time - self._last_update).total_seconds() >= 5:
                self._write_by_batches(self._batch)
                self._batch = ""
                self._last_update = current_time
            self._batch = f"{self._batch}\n{msg}"
        
            
    def solve_model(solver, model, options):
        """Blocking wrapper around ``solver.solve`` so it can be handed to
        ``loop.run_in_executor``; *options* are forwarded as keyword args."""
        result = solver.solve(model, **options)
        return result
    
    
    app = FastAPI()

    @app.get("/")
    async def get():
        # Serve the static HTML client (template.py) that drives the WebSocket.
        return HTMLResponse(template)
    
    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """On each client message, run the blocking solve in a worker thread
        while ``MyStdOut`` streams batched solver output over the WebSocket."""
        await websocket.accept()
        # --- Solve Model ---
        solver = SolverFactory("scip")
        original_stdout = sys.stdout
        my_stdout = MyStdOut(ws=websocket, orig_stdout=original_stdout)
        sys.stdout = my_stdout
        try:
            while True:
                data = await websocket.receive_text()
                await websocket.send_text("Optimizing...")
                loop = asyncio.get_running_loop()
                # Run the blocking solve off the event loop so the WebSocket
                # stays responsive while MyStdOut streams the log batches.
                # (A bare await replaces the original single-item gather.)
                await loop.run_in_executor(
                    None, solve_model, solver, instance, {"tee": True}
                )
                # Flush whatever is still batched, then signal completion.
                await websocket.send_text(
                    f"{my_stdout._batch} \n\n\nOptimization finished"
                )
                my_stdout._batch = ""
        finally:
            # Restore stdout exactly once, even on disconnect. The original
            # restored it INSIDE the loop, so every solve after the first ran
            # without the redirect and `sys.stdout._batch` raised
            # AttributeError on the real stdout.
            sys.stdout = original_stdout