Tags: python, django, scikit-learn, celery, python-3.7

Celery task with Scikit-Learn doesn't use more than a single core


I am trying to create an API endpoint that starts a classification task asynchronously in a Django backend, with the ability to retrieve the result later. This is what I have done so far:

celery.py


import os
from celery import Celery

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings")

app = Celery("backend")
app.config_from_object("django.conf:settings", namespace="CELERY")
app.autodiscover_tasks()
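
For context, the worker for this setup would typically be started with something like the following (assuming the project package is named backend, as in the settings module above); the default pool is prefork, which turns out to matter here:

celery -A backend worker --loglevel=INFO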

tasks.py

from celery import shared_task
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

@shared_task
def Pipeline(taskId):
    # ...read file, preprocess, train_test_split

    # n_jobs=-1 asks joblib to run the grid search on all available cores
    clf = GridSearchCV(SVC(), paramGrid, cv=5, n_jobs=-1)
    clf.fit(XTrain, yTrain)

    # ...compute the metrics

Django view for executing the task: views.py

from django.http import Http404
from rest_framework.views import APIView

from .models import TaskModel  # imports assume the usual app layout
from .tasks import Pipeline

class TaskExecuteView(APIView):
    def get(self, request, taskId, *args, **kwargs):
        try:
            task = TaskModel.objects.get(taskId=taskId)
        except TaskModel.DoesNotExist:
            raise Http404
        else:
            Pipeline.delay(taskId)
            # ... Database updates
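
Since the goal is to retrieve the result later, a companion polling endpoint could look roughly like the sketch below. This is hypothetical (TaskStatusView and the status field are my assumptions, not from the original post); it relies on the task writing its progress back to TaskModel, as hinted by the "Database updates" comment above:

from rest_framework.response import Response

class TaskStatusView(APIView):
    def get(self, request, taskId, *args, **kwargs):
        task = TaskModel.objects.get(taskId=taskId)
        # status is an assumed field that the task keeps up to date
        return Response({"status": task.status})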

The problem is that the task only ever uses a single CPU core, so it takes a very long time to complete. The worker also logs this warning: Warning: Loky-backed parallel loops cannot be called in a multiprocessing, setting n_jobs=1. Is there a way to solve this?

I am aware of a similar question on SO from 2018, but it has no definitive answer, and I have had no luck finding a solution so far. Thanks for your time and for any suggestions or solutions, although I would rather not change the tech stack unless I have to.

What I have tried so far:

Calling threading.current_thread().setName("MainThread") as the first line of the Celery task, but that did not work.
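
For completeness, that attempt looked roughly like this (reconstructed from the description above). As far as I can tell it cannot help, because the warning quoted earlier comes from joblib's loky backend checking multiprocessing.current_process().daemon, not the thread name:

import threading
from celery import shared_task

@shared_task
def Pipeline(taskId):
    # attempted workaround: masquerade as the main thread; joblib still
    # sees a daemonic worker process and forces n_jobs back to 1
    threading.current_thread().setName("MainThread")
    # ...rest of the pipeline as above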

EDIT: requirements.txt

amqp==5.0.2
asgiref==3.3.1
billiard==3.6.3.0
celery==5.0.5
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
click==7.1.2
click-didyoumean==0.0.3
click-plugins==1.1.1
click-repl==0.1.6
cryptography==3.4.6
defusedxml==0.7.1
dj-rest-auth==2.1.3
Django==3.1.3
django-allauth==0.44.0
django-cors-headers==3.6.0
djangorestframework==3.12.2
djangorestframework-simplejwt==4.6.0
idna==2.10
importlib-metadata==3.3.0
joblib==1.0.0
kombu==5.0.2
mccabe==0.6.1
numpy==1.19.4
oauthlib==3.1.0
pandas==1.2.0
Pillow==8.0.1
prompt-toolkit==3.0.8
pycodestyle==2.6.0
pycparser==2.20
PyJWT==2.0.1
python-dateutil==2.8.1
python3-openid==3.2.0
pytz==2020.4
redis==3.5.3
requests==2.25.1
requests-oauthlib==1.3.0
scikit-learn==0.24.0
scipy==1.6.0
sqlparse==0.4.1
threadpoolctl==2.1.0
typed-ast==1.4.1
typing-extensions==3.7.4.3
urllib3==1.26.4
vine==5.0.0
wcwidth==0.2.5
wrapt==1.11.2
zipp==3.4.0

Solution

  • I solved this issue by switching over to django_rq (see the django-rq project on GitHub).

    I don't fully understand the underlying parallel/distributed-computing concepts myself, but the issue is that Celery's prefork pool runs each task inside a daemonic child process, and a daemonic process is not allowed to spawn children of its own. So joblib cannot start its worker processes from inside the task and falls back to a single core.
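
    This restriction is easy to reproduce outside Celery. A minimal standalone sketch (my own illustration, not from the original post):

    import multiprocessing as mp

    def child():
        print("hello from the grandchild")

    def parent():
        # raises AssertionError: daemonic processes are not allowed to
        # have children, because parent() itself runs in a daemonic process
        mp.Process(target=child).start()

    if __name__ == "__main__":
        p = mp.Process(target=parent, daemon=True)  # like a prefork pool worker
        p.start()
        p.join()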

    I don't know exactly how django_rq handles this, but I only changed two lines of code, and that solved the issue.

    from django_rq import job

    @job('default', timeout=3600)  # <-- changed here (was @shared_task)
    def Pipeline(taskId):
        # ...read file, preprocess, train_test_split

        clf = GridSearchCV(SVC(), paramGrid, cv=5, n_jobs=-1)
        clf.fit(XTrain, yTrain)

        # ...compute the metrics

    

    ...and the rest of the API remained the same.
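
    Two details worth noting, as far as I understand them: rq's job decorator also attaches a .delay() method to the decorated function, which is why the Pipeline.delay(taskId) call in the view kept working unchanged, and the jobs are picked up by a worker started with django_rq's management command:

    python manage.py rqworker default

    From what I can tell, RQ forks a fresh, non-daemonic "work horse" process for each job, unlike a prefork Celery worker, so joblib is free to spawn its own worker processes inside it.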

    I will update this answer once I have a better grasp of the underlying parallel-computation concepts and of why Celery failed where django_rq succeeded.