Tags: python-3.x, redis, classification, huggingface-transformers, ray

How to parallelize classification with Zero Shot Classification by Huggingface?


I have around 70 categories (it could also be 20 or 30) and I want to parallelize the process using Ray, but I get an error:

import pandas as pd
import swifter
import json
import ray
from transformers import pipeline

classifier = pipeline("zero-shot-classification")

labels = ["vegetables", "potato", "bell pepper", "tomato", "onion", "carrot", "broccoli",
          "lettuce", "cucumber", "celery", "corn", "garlic", "mashrooms", "cabbage", "spinach",
          "beans", "cauliflower", "asparagus", "fruits", "bananas", "apples", "strawberries",
          "grapes", "oranges", "lemons", "avocados", "peaches", "blueberries", "pineapple",
          "cherries", "pears", "mangoe", "berries", "red meat", "beef", "pork", "mutton",
          "veal", "lamb", "venison", "goat", "mince", "white meat", "chicken", "turkey",
          "duck", "goose", "pheasant", "rabbit", "Processed meat", "sausages", "bacon",
          "ham", "hot dogs", "frankfurters", "tinned meat", "salami", "pâtés", "beef jerky",
          "chorizo", "pepperoni", "corned beef", "fish", "catfish", "cod", "pangasius", "pollock",
          "tilapia", "tuna", "salmon", "seafood", "shrimp", "squid", "mussels", "scallop",
          "octopus", "grains", "rice", "wheat", "bulgur", "corn", "oat", "quinoa", "buckwheat",
          "meals", "salad", "soup", "steak", "pizza", "pie", "burger", "backery", "bread", "souce",
          "pasta", "sandwich", "waffles", "barbecue", "roll", "wings", "ribs", "cookies"]


ray.init()
@ray.remote
def get_meal_category(seq, labels, n=3):
    res_dict = classifier(seq, labels)
    return list(zip([seq for i in range(n)], res_dict["labels"][0:n], res_dict["scores"][0:n]))

res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])

Where merged_df is a big dataframe with meal names in its title column, like:

['Cappuccino',
 'Stove Top Stuffing Mix For Turkey (Kraft)',
 'Stove Top Stuffing Mix For Turkey (Kraft)',
 'Roasted Dark Turkey Meat',
 'Roasted Dark Turkey Meat',
 'Roasted Dark Turkey Meat',
 'Cappuccino',
 'Low Fat 2% Small Curd Cottage Cheese (Daisy)',
 'Rice Cereal (Gerber)',
 'Oranges']

Please advise how to avoid Ray's error and parallelize the classification.

The error:

2021-02-17 16:54:51,689 WARNING worker.py:1107 -- Warning: The remote function __main__.get_meal_category has size 1630925709 when pickled. It will be stored in Redis, which could cause memory issues. This may mean that its definition uses a large array or other object.
---------------------------------------------------------------------------
ConnectionResetError                      Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
    705             for item in command:
--> 706                 sendall(self._sock, item)
    707         except socket.timeout:

~/.local/lib/python3.8/site-packages/redis/_compat.py in sendall(sock, *args, **kwargs)
      8 def sendall(sock, *args, **kwargs):
----> 9     return sock.sendall(*args, **kwargs)
     10 

ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

ConnectionError                           Traceback (most recent call last)
<ipython-input-9-1a5345832fba> in <module>
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])

<ipython-input-9-1a5345832fba> in <listcomp>(.0)
----> 1 res_list = ray.get([get_meal_category.remote(merged_df["title"][i], labels) for i in range(10)])

~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote_proxy(*args, **kwargs)
     99         @wraps(function)
    100         def _remote_proxy(*args, **kwargs):
--> 101             return self._remote(args=args, kwargs=kwargs)
    102 
    103         self.remote = _remote_proxy

~/.local/lib/python3.8/site-packages/ray/remote_function.py in _remote(self, args, kwargs, num_returns, num_cpus, num_gpus, memory, object_store_memory, accelerator_type, resources, max_retries, placement_group, placement_group_bundle_index, placement_group_capture_child_tasks, override_environment_variables, name)
    205 
    206             self._last_export_session_and_job = worker.current_session_and_job
--> 207             worker.function_actor_manager.export(self)
    208 
    209         kwargs = {} if kwargs is None else kwargs

~/.local/lib/python3.8/site-packages/ray/function_manager.py in export(self, remote_function)
    142         key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
    143                + remote_function._function_descriptor.function_id.binary())
--> 144         self._worker.redis_client.hset(
    145             key,
    146             mapping={

~/.local/lib/python3.8/site-packages/redis/client.py in hset(self, name, key, value, mapping)
   3048                 items.extend(pair)
   3049 
-> 3050         return self.execute_command('HSET', name, *items)
   3051 
   3052     def hsetnx(self, name, key, value):

~/.local/lib/python3.8/site-packages/redis/client.py in execute_command(self, *args, **options)
    898         conn = self.connection or pool.get_connection(command_name, **options)
    899         try:
--> 900             conn.send_command(*args)
    901             return self.parse_response(conn, command_name, **options)
    902         except (ConnectionError, TimeoutError) as e:

~/.local/lib/python3.8/site-packages/redis/connection.py in send_command(self, *args, **kwargs)
    723     def send_command(self, *args, **kwargs):
    724         "Pack and send a command to the Redis server"
--> 725         self.send_packed_command(self.pack_command(*args),
    726                                  check_health=kwargs.get('check_health', True))
    727 

~/.local/lib/python3.8/site-packages/redis/connection.py in send_packed_command(self, command, check_health)
    715                 errno = e.args[0]
    716                 errmsg = e.args[1]
--> 717             raise ConnectionError("Error %s while writing to socket. %s." %
    718                                   (errno, errmsg))
    719         except BaseException:

ConnectionError: Error 104 while writing to socket. Connection reset by peer.

Solution

  • This error happens because large objects are being pickled and sent to Redis. merged_df is a large dataframe, and since you call get_meal_category 10 times, anything large passed from it would be serialized on every call; more importantly, the classifier is captured in the remote function's closure, which is why the pickled function is about 1.6 GB (see the warning in your traceback). Instead, put both merged_df and the classifier into the Ray object store just once with ray.put, and pass references to them into the remote task.

    Can you try something like this:

    ray.init()
    
    # Put the large objects into the object store once instead of
    # serializing them on every remote call.
    df_ref = ray.put(merged_df)
    model_ref = ray.put(classifier)
    
    @ray.remote
    def get_meal_category(classifier, df, i, labels, n=3):
        # ObjectRefs passed as top-level arguments are resolved automatically,
        # so `classifier` and `df` are the actual objects inside the task.
        seq = df["title"][i]
        res_dict = classifier(seq, labels)
        return list(zip([seq] * n, res_dict["labels"][:n], res_dict["scores"][:n]))
    
    res_list = ray.get([get_meal_category.remote(model_ref, df_ref, i, labels) for i in range(10)])
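
    Passing the ObjectRefs keeps serialization down to a single ray.put per object. If putting the whole classifier into the object store is still too heavy for your setup, another option is to load the model inside a Ray actor, so it is constructed in the worker process and never pickled at all. This is only a minimal sketch, not part of the original answer: it reuses the imports, labels and merged_df from above, assumes each worker has enough memory for its own pipeline, and the ZeroShotWorker name and worker count are illustrative.

    @ray.remote
    class ZeroShotWorker:
        def __init__(self):
            # The pipeline is built inside the worker process, so the large
            # model never has to be pickled or pushed through Redis.
            self.classifier = pipeline("zero-shot-classification")
    
        def classify(self, seq, labels, n=3):
            res_dict = self.classifier(seq, labels)
            return list(zip([seq] * n, res_dict["labels"][:n], res_dict["scores"][:n]))
    
    # Four workers is an arbitrary choice; round-robin the titles across them.
    workers = [ZeroShotWorker.remote() for _ in range(4)]
    futures = [workers[i % 4].classify.remote(merged_df["title"][i], labels)
               for i in range(10)]
    res_list = ray.get(futures)

    Each actor keeps its model in memory between calls, so the cost of loading the pipeline is paid once per worker rather than once per task.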