Tags: python, redis, pytorch, celery, openai-whisper

Load and unload model on celery worker


I currently have a system that runs tasks on the Whisper AI model using Celery. However, the existing setup loads the model inside each task, which is suboptimal because the repeated loading takes a significant amount of time. Unfortunately, I cannot keep the model loaded all the time, because other systems need the GPU VRAM.

To address this, I am considering a system that loads the model when a task is received and unloads it once there are no tasks left. The model would then occupy GPU memory only while there is work to do, and tasks that arrive back to back would not each pay the loading cost.
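A minimal sketch of that idea as a reference-counted loader, assuming the counter and the model live in the same process that actually runs the tasks (for example, called from inside the task body with a thread-based pool such as -P threads). LazyModel, load_fn, on_unload and run are hypothetical names; in practice load_fn could be lambda: whisper.load_model("medium") and on_unload could be torch.cuda.empty_cache.

```python
import threading


class LazyModel:
    """Hypothetical helper: load a model when the first task starts, release it when the last one ends."""

    def __init__(self, load_fn, on_unload=None):
        self._load_fn = load_fn        # e.g. lambda: whisper.load_model("medium")
        self._on_unload = on_unload    # e.g. torch.cuda.empty_cache
        self._lock = threading.Lock()  # guards the counter and the model reference
        self._active = 0
        self.model = None

    def _acquire(self):
        with self._lock:
            if self.model is None:
                self.model = self._load_fn()  # only the first concurrent task pays the load cost
            self._active += 1
            return self.model

    def _release(self):
        with self._lock:
            self._active -= 1
            if self._active == 0:
                self.model = None  # drop the last reference held by the helper
                if self._on_unload is not None:
                    self._on_unload()

    def run(self, fn, *args, **kwargs):
        """Call fn(model, *args, **kwargs) with the model loaded."""
        model = self._acquire()
        try:
            return fn(model, *args, **kwargs)
        finally:
            del model  # drop this frame's reference before a possible unload
            self._release()
```

A task body could then call lazy.run(launchAnalyticsWhisper). Note that the model is only kept between tasks while they overlap: with one task at a time the count drops to zero after every task, so it would be reloaded each time unless a short idle delay is added before unloading.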

Here's the idea (it obviously doesn't work):

```python
from celery import Celery
from celery.signals import (
    task_received,
    celeryd_after_setup,
    task_success,
    task_failure,
)
import redis
import torch
import whisper

# The Celery app (celery = Celery(...)), hostname, activesTasksWorker and
# launchAnalyticsWhisper are assumed to be defined elsewhere in the project
r = redis.Redis(host=hostname)

model = None

def checkActivesTasksWorker():
    global model
    if r.llen(activesTasksWorker) == 0:
        # Release the model when no task is left on the worker
        model = None
        torch.cuda.empty_cache()

@celery.task
def myTask():
    launchAnalyticsWhisper(model)

@task_received.connect
def taskReceivedHandler(sender, request, **kwargs):
    global model
    if r.llen(activesTasksWorker) == 0:
        # Load the model for the first pending task
        model = whisper.load_model("medium")
    r.lpush(activesTasksWorker, request.id)

@task_success.connect(sender=myTask)
def taskSuccessHandler(sender, result, **kwargs):
    # sender is the task; its current request carries the id of the finished task
    r.lrem(activesTasksWorker, 1, sender.request.id)
    checkActivesTasksWorker()

@task_failure.connect(sender=myTask)
def taskFailureHandler(sender, task_id, exception, **kwargs):
    r.lrem(activesTasksWorker, 1, task_id)
    checkActivesTasksWorker()

@celeryd_after_setup.connect
def initList(sender, instance, **kwargs):
    # Clear the list of active tasks when the worker starts
    r.delete(activesTasksWorker)
```

model is None everywhere, so the global variable doesn't work. Do you have an idea how to make this work?


Solution

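  • With the default prefork pool, Celery runs tasks in child processes that are separate from the main worker process. The task_received handler runs in the main process, while the task itself (and the task_success / task_failure handlers) runs in a child process, so a global variable set in one process is not visible in the other.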

    The best way to do this is to run a dedicated inference server for the model in question. The inference server loads the model once at startup, and the Celery task sends requests to it instead of loading the model itself.

    If you have to juggle multiple models on the same GPU, you can add an endpoint to the inference server to move the model between CPU memory and GPU memory. That way you can keep GPU memory free while still only loading the model from disk once.
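A minimal, standard-library-only sketch of such an inference server, with a /transcribe endpoint and an /offload endpoint that moves the model back to CPU memory. The endpoint paths, JSON fields and the ModelService, make_server and call_server names are hypothetical. The loader is passed in so the sketch runs without a GPU; for Whisper it could be lambda: whisper.load_model("medium", device="cpu"), with torch.cuda.empty_cache as the cache-clearing callback. It assumes a PyTorch-style model.to(device) and Whisper's model.transcribe(path) returning a dict with a "text" key.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.request import Request, urlopen


class ModelService:
    """Owns one model: loaded from disk once, moved between CPU and GPU on demand."""

    def __init__(self, load_fn, empty_cache_fn=None):
        self.model = load_fn()          # load from disk exactly once, onto the CPU
        self.device = "cpu"
        self._empty_cache = empty_cache_fn
        self._lock = threading.Lock()   # serialises inference and device moves

    def transcribe(self, audio_path):
        with self._lock:
            if self.device != "cuda":
                self.model = self.model.to("cuda")  # RAM -> VRAM, no disk read
                self.device = "cuda"
            return self.model.transcribe(audio_path)["text"]

    def offload(self):
        with self._lock:
            if self.device != "cpu":
                self.model = self.model.to("cpu")
                self.device = "cpu"
                if self._empty_cache is not None:
                    self._empty_cache()  # return cached VRAM so other processes can use it


def make_server(service, host="127.0.0.1", port=8000):
    class Handler(BaseHTTPRequestHandler):
        def do_POST(self):
            length = int(self.headers.get("Content-Length", 0))
            body = json.loads(self.rfile.read(length) or b"{}")
            if self.path == "/transcribe":
                payload = {"text": service.transcribe(body["audio_path"])}
            elif self.path == "/offload":
                service.offload()
                payload = {"device": service.device}
            else:
                self.send_error(404)
                return
            data = json.dumps(payload).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(data)))
            self.end_headers()
            self.wfile.write(data)

        def log_message(self, *args):
            pass  # silence per-request logging in this sketch

    return ThreadingHTTPServer((host, port), Handler)


def call_server(base_url, path, payload):
    """Client helper a Celery task could call instead of loading the model itself."""
    req = Request(base_url + path, data=json.dumps(payload).encode(),
                  headers={"Content-Type": "application/json"}, method="POST")
    with urlopen(req) as resp:
        return json.loads(resp.read())
```

The Celery task then only does something like call_server("http://127.0.0.1:8000", "/transcribe", {"audio_path": path}), which assumes the server can read the file at that path, and whatever needs the VRAM can POST to /offload first. Because one lock covers both inference and device moves, an offload request waits for a running transcription to finish.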