Search code examples
Tags: python, tensorflow, multiprocessing, tensorflow-datasets

multiprocessing.Pool does not speed up an augmentation function in tf.data


I want to use multiprocessing.Pool inside a tf.data pipeline to speed up my augmentation function, but the result is slower than a plain for loop.

With multiprocessing.Pool it takes about 72 s.

With a plain for loop it takes about 57 s.

My environment: python3.6, tensorflow-gpu2.4.0, Ubuntu20.04

Below is my code, what am I doing wrong?

Thanks in advance!

import numpy as np
import tensorflow as tf
from functools import partial
import multiprocessing

INPUT_SHAPE = (2000,6)
OUTPUT_SHAPE = (200,6)


def resizing(i, data, enable, choice):
    """Reduce one window of rows of ``data`` to six per-column statistics.

    The window spans rows ``i - overlap`` up to (but not including)
    ``i + 10 + overlap``. ``overlap`` is 5 when ``i`` is non-zero and
    ``enable >= 0.5``, and 0 otherwise.

    ``choice`` selects the statistic applied to each of columns 0..5:
    0 -> mean, 1 -> standard deviation, 2 -> maximum, 3 -> minimum.

    Returns a list with one value per column, or ``None`` when ``choice``
    matches none of the four codes.
    """
    # The very first window never reaches backwards; later windows widen by
    # five rows on each side only when the switch is on.
    overlap = 5 if (i != 0 and enable >= 0.5) else 0
    start = i - overlap
    stop = i + 10 + overlap

    # Position in this tuple is the ``choice`` code of the reducer.
    reducers = (np.mean, np.std, np.max, np.min)
    for code, reduce_fn in enumerate(reducers):
        if choice == code:
            return [reduce_fn(data[start:stop, column]) for column in range(6)]
    return None

def resize_data(data, pool_obj):
    """Downsample ``data`` by reducing windows that start every 10 rows.

    One statistic code and one overlap switch are drawn at random per call
    and shared by every window. ``pool_obj.map`` evaluates ``resizing`` for
    the window starts 0, 10, ..., 1990, and the per-window results are
    stacked into a NumPy array (one row per window start).
    """
    # Draw order matters for reproducibility: statistic code first, then the
    # overlap switch.
    choice = tf.random.uniform(
        shape=(), minval=0, maxval=4, dtype=tf.int64
    ).numpy()
    enable = tf.random.uniform(
        shape=(), minval=0, maxval=1, dtype=tf.float64
    ).numpy()

    reduce_window = partial(resizing, data=data, enable=enable, choice=choice)
    window_starts = range(0, 2000, 10)

    # Sequential equivalent, for comparison with the pool version:
    #   [resizing(i, data, enable, choice) for i in range(0, 2000, 10)]
    rows = pool_obj.map(reduce_window, window_starts)

    return np.array(rows)


def augmentation(data, labels, pool_obj):
    """Augment ``data`` through a Python callback and pass ``labels`` through.

    The inner ``aug`` function is wrapped in ``tf.py_function``; it converts
    its argument to NumPy, calls ``resize_data`` with ``pool_obj``, and
    converts the result back to a ``tf.float64`` tensor.

    Returns a ``(data, labels)`` tuple in which ``data`` is the callback's
    output with its static shape set to ``OUTPUT_SHAPE``.
    """
    def aug(data):
        # Leave tensor land: everything below works on a NumPy array.
        data = data.numpy()
        
        # Placeholder kept from the original post (steps not shown).
        ...      

        # 2000 resize to 200
        data = resize_data(data, pool_obj)
        
        # Placeholder kept from the original post (steps not shown).
        ...
        
        return tf.convert_to_tensor(data, tf.float64)

    # A list of output dtypes is requested, so take the single entry back out.
    data = tf.py_function(aug, [data], [tf.float64])[0]
    # Declare the static shape explicitly on the callback's output.
    data.set_shape(OUTPUT_SHAPE)
    return data, labels

def test(trainDS):
    """Iterate ``trainDS`` once, printing each element's index and shapes.

    Every element of ``trainDS`` must unpack into a pair ``(X, y)`` whose
    members expose a ``.shape`` attribute.
    """
    # enumerate supplies the running index; the original printed an
    # undefined name ``i`` and raised NameError on the first element.
    for i, (X, y) in enumerate(trainDS):
        print(i, X.shape, y.shape)
        

if __name__ == '__main__':
    # The worker pool is created before the tf.data pipeline is built and is
    # used as a context manager so the worker processes are always shut down,
    # even if building or iterating the pipeline raises.
    with multiprocessing.Pool() as pool_obj:
        trainDS = tf.data.Dataset.from_tensor_slices(getDataSet_Path())
        trainDS = (
            trainDS
            .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
            .cache()
            .shuffle(300, reshuffle_each_iteration=False)
            .map(partial(augmentation, pool_obj=pool_obj), num_parallel_calls=tf.data.AUTOTUNE)
            .batch(128, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE)
        )

        # Iterate inside the ``with`` block: the pipeline calls into the pool
        # lazily, so the pool must still be alive while elements are produced.
        test(trainDS)

Solution

  • The TensorFlow Dataset API already comes with built-in parallelism. Just use the num_parallel_calls parameter of map together with prefetch, without any Python-level multiprocessing tools. In addition, pass map only TensorFlow-style functions that can be converted to a graph. In particular, avoid plain Python if blocks; use tf.cond, tf.where, etc. instead. NumPy routines are not recommended either; use their TensorFlow equivalents. Follow guides like this.