Quick warning: I've worked with websockets maybe twice, so I'm sure I'm missing something simple or what I have is too complicated (I'm also trying to keep track of connections, hence the connection manager you'll see below).
In any case, I'm trying to stream all stdout from my ML program (running on FastAPI) to a websocket so the stdout from keras can be displayed on a webpage. The stream is fine for simple strings (e.g. "starting training..."), but when I run model.fit, nothing is sent to the websocket until AFTER the model finishes training. When it does finish training, the entirety of the keras logs is sent, so it's like it's all stuck in the buffer. Here's what I have:
websocket route
@router.websocket('/ws')
async def start_websocket_logging(websocket: WebSocket):
    """Websocket endpoint at ``/ws``: hand the connection to ``websocket_helper``."""
    logging_session = websocket_helper.start_websocket_logging(websocket)
    await logging_session
websocket_helper.py
# Module-wide registry of the websocket connections that are currently open.
connection_manager = ConnectionManager()
# Interval, in seconds, that this module's polling code passes to asyncio.sleep.
socket_sleep = 0.5
async def redirect_std_out(websocket):
    """Redirect ``sys.stdout`` into a buffer and stream it to the websocket.

    ``sys.stdout`` is replaced process-wide with an in-memory ``StringIO``.
    A background task wakes up every ``socket_sleep`` seconds and forwards
    whatever has accumulated to ``websocket`` as one text message.

    The sender task can only run while the event loop is free. Synchronous
    work that blocks the loop delays all output until that work returns.

    Args:
        websocket: Connection object exposing an awaitable ``send_text``.

    Returns:
        The ``asyncio.Task`` running the sender loop; cancel it to stop
        streaming.
    """
    stdout_buffer = StringIO()

    async def send_stdout():
        try:
            while True:
                await asyncio.sleep(socket_sleep)
                data = stdout_buffer.getvalue()
                if not data:
                    continue
                # Empty the buffer *before* awaiting the send. Anything
                # printed by other coroutines while the send is pending
                # would otherwise be truncated away without being sent.
                stdout_buffer.seek(0)
                stdout_buffer.truncate()
                await websocket.send_text(data.rstrip('\n'))
        except (asyncio.CancelledError, websockets.ConnectionClosedError):
            # Cancellation and a closed connection are the expected ways for
            # this loop to end, so there is nothing to clean up here.
            pass

    sys.stdout = stdout_buffer
    return asyncio.create_task(send_stdout())
async def stop_std_redirect(task: asyncio.Task):
    """Undo the stdout redirect and shut down the websocket sender task.

    ``sys.stdout`` is pointed back at ``sys.__stdout__``. ``task`` is then
    cancelled and awaited, with any exception it raises collected rather
    than propagated.
    """
    sys.stdout = sys.__stdout__
    task.cancel()
    outcome = asyncio.gather(task, return_exceptions=True)
    await outcome
async def start_websocket_logging(websocket):
    """Serve one websocket that receives this process's stdout.

    Accepts and registers the connection, starts streaming stdout to it,
    and then waits until the client goes away. On exit, stdout is restored
    and the connection is unregistered and closed.

    Args:
        websocket: The incoming websocket connection to stream to.
    """
    socket_id = str(uuid.uuid4())
    await connection_manager.connect(websocket, socket_id)
    redirect_task = await redirect_std_out(websocket)
    try:
        while True:
            # Wait on the client instead of sleeping: a client disconnect is
            # reported through receive calls, so a loop that only sleeps
            # never sees WebSocketDisconnect and never reaches the cleanup
            # below. Incoming messages are intentionally discarded.
            await websocket.receive_text()
    except WebSocketDisconnect:
        # ignore since we're disconnecting in the finally-block
        pass
    except Exception as e:
        print(e)
    finally:
        # Restore stdout first so the message below reaches the real stdout
        # rather than the buffer whose sender task is being cancelled.
        await stop_std_redirect(redirect_task)
        print('websocket disconnected')
        await connection_manager.disconnect(socket_id)
ConnectionManager
class ConnectionManager:
    """Registry of accepted websocket connections, keyed by socket id."""

    def __init__(self):
        # Maps socket_id -> websocket for every connection accepted via connect().
        self.open_sockets = {}

    async def connect(self, websocket: "WebSocket", socket_id: str):
        """Accept ``websocket`` and register it under ``socket_id``."""
        await websocket.accept()
        self.open_sockets[socket_id] = websocket

    async def disconnect(self, socket_id: str):
        """Unregister and close the connection stored under ``socket_id``.

        Unknown ids are ignored, so calling this twice for the same
        connection is a no-op instead of raising ``KeyError``.
        """
        websocket = self.open_sockets.pop(socket_id, None)
        if websocket is None:
            return
        await websocket.close()
I've also tried manually flushing the buffers with a custom callback that I pass to `model.fit`'s `callbacks` parameter, but no dice:
from keras.callbacks import Callback
from sys import stderr, stdout
class FlushStdIOCallback(Callback):
    """Keras callback that prints epoch boundaries and flushes stdio."""

    @staticmethod
    def _flush_std_streams():
        """Flush the objects currently installed as stderr and stdout."""
        # Look the streams up on the sys module at call time. Names bound
        # with `from sys import stdout` keep pointing at the stream that was
        # installed at import time, which is the wrong object once
        # sys.stdout has been reassigned.
        import sys

        sys.stderr.flush()
        sys.stdout.flush()

    def on_epoch_begin(self, epoch, logs=None):
        """Announce the start of ``epoch`` and flush the standard streams."""
        print(f'starting epoch {epoch}')
        self._flush_std_streams()

    def on_epoch_end(self, epoch, logs=None):
        """Announce the end of ``epoch`` and flush the standard streams."""
        print(f'finished epoch {epoch}')
        self._flush_std_streams()
I've looked at a few different implementations and they seem to be fairly similar.
I resolved the issue by making my training method asynchronous and running `asyncio.sleep(0.5)` after running `keras.model.fit`. It appears the issue was essentially due to the manual buffer flush being blocked by the synchronous training method.