Tags: pandas, keras, multiprocessing, pool, weak-references

Parallelizing Keras Model Predict Using Multiprocessing


I have a system with 60 CPUs. I want to parallelize prediction with a Keras model across several images. I tried the following code:

import tensorflow

img_model1 = tensorflow.keras.models.load_model('my_model.h5')
img_model2 = tensorflow.keras.models.load_model('my_model.h5')
img_model3 = tensorflow.keras.models.load_model('my_model.h5')
models = [img_model1, img_model2, img_model3]  # all three are the same model

To avoid the weakref pickling error, I tried passing indices into the list of models instead of the models themselves:

def _apply_df(args):
    df, model_index = args
    preds = prediction_generator(df['image_path'])
    return models[model_index].predict(preds)

def apply_by_multiprocessing(df, workers):
    pool = Pool(processes=workers)
    result = pool.map(_apply_df, [(d, i) for i, d in enumerate(np.array_split(df[['image_path']], workers))])
    pool.close()
    return pd.concat(list(result))

apply_by_multiprocessing(df=data, workers=3) 

The code keeps running forever without yielding any results. I guess the problem could be solved with tf.Session(), but I'm not sure how...


Solution

  • Load your model inside the _apply_df function, so that the model itself never gets pickled and sent to the worker processes.

    Here is a simple example, without pandas, that runs a model on the Fashion-MNIST data; you should be able to adapt it to your use case.

    import tensorflow as tf
    import numpy as np
    from multiprocessing import Pool
    
    
    def _apply_df(data):
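        # Each worker process loads its own copy of the model here, so the model never has to be pickled.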
        model = tf.keras.models.load_model("my_fashion_mnist_model.h5")
        return model.predict(data)
    
    
    def apply_by_multiprocessing(data, workers):
    
        pool = Pool(processes=workers)
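        # np.array_split gives one chunk per worker; pool.map blocks until every chunk has been processed.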
        result = pool.map(_apply_df, np.array_split(data, workers))
        pool.close()
        return list(result)
    
    
    def main():
        fashion_mnist = tf.keras.datasets.fashion_mnist
        _, (test_images, test_labels) = fashion_mnist.load_data()
    
        test_images = test_images / 255.0
        results = apply_by_multiprocessing(test_images, workers=3)
        print(test_images.shape)           # (10000, 28, 28)
        print(len(results))                # 3
        print([x.shape for x in results])  # [(3334, 10), (3333, 10), (3333, 10)]
    
    
    if __name__ == "__main__":
        main()
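
    To adapt this back to the pandas setup from the question, the same principle applies: send only plain, picklable data (the chunks of image paths) to the workers and load the model inside the worker function. Below is a minimal sketch, assuming the prediction_generator helper and my_model.h5 from the question; it is an illustration of the pattern, not a drop-in implementation.

    import numpy as np
    import tensorflow as tf
    from multiprocessing import Pool


    def _apply_df(df):
        # The model is loaded inside the worker process, so it never has to be pickled.
        model = tf.keras.models.load_model("my_model.h5")
        preds = prediction_generator(df["image_path"])  # the question's own generator
        return model.predict(preds)


    def apply_by_multiprocessing(df, workers):
        pool = Pool(processes=workers)
        # One dataframe chunk per worker; pool.map preserves the chunk order.
        results = pool.map(_apply_df, np.array_split(df[["image_path"]], workers))
        pool.close()
        pool.join()
        # Stack the per-worker prediction arrays back into one array, in row order.
        return np.concatenate(results)

    As in the example above, keep the top-level call (e.g. apply_by_multiprocessing(df=data, workers=3)) behind an if __name__ == "__main__": guard, so that the worker processes can re-import the module safely.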