How to reconnect to a Ray cluster after the cluster has restarted?


I have a question regarding the reconnection process between a Ray cluster and a FastAPI server. In FastAPI, I connect to the Ray cluster in the startup event:

@app.on_event("startup")
async def init_ray():
    ...
    ray.init(address=f'{ray_head_host}:{ray_head_port}', _redis_password=ray_redis_password, namespace=ray_serve_namespace)

    ...

After the Ray cluster restarts, I run into a problem when I use the Ray API in some FastAPI routes:

Exception: Ray Client is not connected. Please connect by calling `ray.connect`.

So it seems that the connection from FastAPI to Ray is lost (this is also confirmed by ray.is_initialized() returning False). But if I try to reconnect using ray.init(), I get the following error:

Exception: ray.connect() called, but ray client is already connected

I also tried calling ray.shutdown() before the re-init call, without success.

Maybe someone has an idea how to reconnect from FastAPI?


Solution

  • You can create a daemon thread that periodically checks the Ray connection. If the Ray client is disconnected, shut it down and reconnect by calling your startup function init_ray() (shown as reestablish_conn() in the code).

    import logging
    import threading
    import time

    import ray
    from ray.util.client import ray as ray_stub

    logger = logging.getLogger(__name__)


    class RayConn(threading.Thread):
        """Daemon thread that reconnects the Ray client when the connection drops."""

        def __init__(self):
            super().__init__(daemon=True)
            self.reconnect_count = 0
            self.start()

        def run(self):
            while True:
                # Check the connection every 30 seconds.
                time.sleep(30)
                if not ray_stub.is_connected():
                    logger.error("Ray client is disconnected. Trying to reconnect")
                    try:
                        try:
                            # Tear down the stale client before reconnecting.
                            ray.shutdown()
                            logger.info("Shutdown complete.")
                        except Exception as e:
                            logger.error(f"Failed to shutdown: {e}")
                        reestablish_conn()  # your function that calls ray.init() and recreates tasks, if any
                        self.reconnect_count += 1
                        logger.info(f"Successfully reconnected, reconnect count: {self.reconnect_count}")
                    except Exception as ee:
                        logger.error(f"Failed to connect to ray head! {ee}")


    RayConn()
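
If a fixed 30-second polling interval is too coarse, the same check-shutdown-reconnect step could also be run on demand, for example at the start of a route that uses Ray. The following is only a minimal sketch of that idea. The helper name reconnect_if_needed and its three callable parameters are hypothetical and not part of Ray or FastAPI; passing the checks in as callables keeps the helper independent of a running cluster.

```python
import logging
from typing import Callable

logger = logging.getLogger(__name__)


def reconnect_if_needed(
    is_connected: Callable[[], bool],
    shutdown: Callable[[], None],
    connect: Callable[[], None],
) -> bool:
    """Run one check/reconnect cycle; return True if the connection is usable afterwards.

    All three parameters are hypothetical hooks supplied by the caller.
    """
    if is_connected():
        return True

    logger.warning("Ray client is disconnected. Trying to reconnect")
    try:
        # A failed shutdown is logged but does not stop the reconnect attempt.
        shutdown()
    except Exception as exc:
        logger.error("Failed to shutdown: %s", exc)

    try:
        connect()
    except Exception as exc:
        logger.error("Failed to connect to ray head! %s", exc)
        return False
    return True
```

With the names from the answer above, a call might look like reconnect_if_needed(ray_stub.is_connected, ray.shutdown, reestablish_conn), where reestablish_conn is still your own function that calls ray.init(). Whether a per-request check or a background thread fits better depends on how quickly routes need to recover after a cluster restart.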