
TensorFlow 2.6: num_parallel_calls is greater than 1 but only one CPU core is used most of the time


I wrote a tf.data pipeline that looks roughly like this (TF 2.6):

```python
import tensorflow as tf

# IMG_SHAPE, TARGET_DTYPE, MAX_LEN, BATCH_SIZE_OVERALL and TRAIN_TFREC_PATHS
# are defined elsewhere.


def parse(img):
    # Decode one PNG string into a fixed-shape tensor of TARGET_DTYPE.
    image = tf.image.decode_png(img, channels=3)
    image = tf.reshape(image, IMG_SHAPE)
    image = tf.cast(image, TARGET_DTYPE)
    return image


def decode_batch(serialized_example, is_test=False):
    feature_dict = {
        'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
    }

    if not is_test:
        feature_dict["some_text"] = tf.io.FixedLenFeature(shape=[MAX_LEN], dtype=tf.int64, default_value=[0]*MAX_LEN)
    else:
        feature_dict["image_id"] = tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value='')

    # Parse the whole batch of serialized records in one vectorized call.
    features = tf.io.parse_example(tf.reshape(serialized_example, [BATCH_SIZE_OVERALL]), features=feature_dict)
    # Decode the images of the batch, up to 4 in parallel.
    images = tf.map_fn(parse, features['image'], parallel_iterations=4, fn_output_signature=TARGET_DTYPE)

    if is_test:
        image_ids = features["image_id"]
        return images, image_ids
    else:
        targets = tf.cast(features["some_text"], tf.uint8)
        return images, targets


def get_dataset(filenames, is_test):
    opts = tf.data.Options()
    opts.experimental_deterministic = False  # allow out-of-order elements for speed
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.with_options(opts)
    # Read 4 TFRecord files concurrently.
    dataset = dataset.interleave(lambda x:
        tf.data.TFRecordDataset(x),
        cycle_length=4,
        num_parallel_calls=4,
    )
    dataset = dataset.batch(BATCH_SIZE_OVERALL, num_parallel_calls=4, drop_remainder=True)
    if not is_test:
        dataset = dataset.repeat()
        # This runs after batch(), so it shuffles whole batches.
        dataset = dataset.shuffle(BATCH_SIZE_OVERALL*6)
    # Parse and decode each batch.
    dataset = dataset.map(lambda y: decode_batch(y, is_test), num_parallel_calls=4)

    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset


train_ds = get_dataset(TRAIN_TFREC_PATHS, False)
```
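The pipeline above batches first and shuffles afterwards, so its shuffle buffer holds whole batches (BATCH_SIZE_OVERALL*6 of them, which must be filled before the first element comes out) rather than individual records. For comparison, here is a minimal sketch of the ordering the tf.data guides generally suggest: shuffle records, then repeat, batch, run a vectorized map over each batch, and prefetch last. It assumes a hypothetical in-memory set of toy tf.train.Example records instead of the real TFRecord files; `make_example`, `FEATURE_SPEC` and `build_pipeline` are illustrative names, and in the real pipeline the interleave over TFRecordDataset would stay in front of the shuffle. This reordering is not what turned out to fix the problem described here.

```python
import tensorflow as tf


def make_example(value):
    """Serialize a toy tf.train.Example with one int64 feature (hypothetical data)."""
    feature = {"value": tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))}
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# Hypothetical feature spec for the toy records.
FEATURE_SPEC = {"value": tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}


def build_pipeline(serialized_records, batch_size, training):
    ds = tf.data.Dataset.from_tensor_slices(serialized_records)
    if training:
        # Shuffle individual records before batching, then repeat, so records
        # are reshuffled every epoch and epoch boundaries are preserved.
        ds = ds.shuffle(buffer_size=batch_size * 6)
        ds = ds.repeat()
    # drop_remainder=True gives every batch the same static shape.
    ds = ds.batch(batch_size, drop_remainder=True)
    # Vectorized parse of a whole batch; several batches are processed concurrently.
    ds = ds.map(lambda batch: tf.io.parse_example(batch, FEATURE_SPEC)["value"],
                num_parallel_calls=tf.data.AUTOTUNE)
    # Overlap preprocessing of the next batches with consumption of the current one.
    return ds.prefetch(tf.data.AUTOTUNE)
```

With ten records and a batch size of 4, the evaluation pipeline yields `[0, 1, 2, 3]` and `[4, 5, 6, 7]` and drops the last two records.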

As you can see from the code, I applied most of the tricks from the TF guide on building an efficient tf.data pipeline. The problem is this: once training starts, the code uses only 1 of the 4 cores instead of all of them (sometimes more cores are used, but that seems to be caused by the train_dist_ds.get_next() call in the code below). The GPU is also barely utilized. The profiler says the problem is in preprocessing, and tf_data_bottleneck_analysis points to ParallelBatch (once it pointed to ParallelMap instead, which seems plausible, but that alone doesn't say much: the cores are underutilized either way). The training function with the profiler looks like this:

```python
import time

import tensorflow as tf

# stat_logger and train_step (wrapped in @tf.function) are defined elsewhere.


def fit_profile(train_ds, val_ds, stop_after_steps):
    tf.profiler.experimental.start('logdir')
    stat_logger.current_step = 0

    train_dist_ds = iter(train_ds)

    while True:
        stat_logger.batch_start_time = time.time()
        stat_logger.current_step += 1
        print(f'current step: {stat_logger.current_step}')
        # Only get_next() is inside the profiler's step marker here; the TF
        # profiler guide puts the whole step (input + train_step) inside Trace.
        with tf.profiler.experimental.Trace('train', step_num=stat_logger.current_step, _r=1):
            image_batch, some_text_batch = train_dist_ds.get_next()
        train_step(image_batch, some_text_batch)
        if stat_logger.current_step == stop_after_steps:
            break

    tf.profiler.experimental.stop()
```

As you can see, I don't do anything else with the dataset: I don't distribute it with any strategy, and all the model computation happens in train_step (which is, of course, wrapped in @tf.function). My questions: is there a way to debug the computations inside the graph for tf.data operations, in particular at the level of the individual tf.data API calls in the preprocessing, so that I can understand exactly what to optimize? And what could be the reason that only one core is used?
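As a coarse first step before digging into individual tf.data ops, one can time how long each step waits on the input iterator versus how long the training step itself takes. Below is a minimal pure-Python sketch; `time_steps` and `input_wait_fraction` are hypothetical helpers, not TensorFlow APIs. Because TF dispatches GPU work asynchronously, `step_fn` should force its result (for example by calling `.numpy()` on a returned loss), otherwise the compute time is undercounted.

```python
import time
from typing import Any, Callable, Iterator, List, Tuple


def time_steps(batches: Iterator[Any], step_fn: Callable[[Any], Any],
               num_steps: int) -> List[Tuple[float, float]]:
    """Return (input_wait_seconds, compute_seconds) for each of num_steps steps."""
    timings = []
    for _ in range(num_steps):
        t0 = time.perf_counter()
        batch = next(batches)  # time blocked on the input pipeline
        t1 = time.perf_counter()
        step_fn(batch)  # time spent in the training step itself
        t2 = time.perf_counter()
        timings.append((t1 - t0, t2 - t1))
    return timings


def input_wait_fraction(timings: List[Tuple[float, float]]) -> float:
    """Share of measured time spent waiting for input (close to 1.0 means input-bound)."""
    wait = sum(w for w, _ in timings)
    total = wait + sum(c for _, c in timings)
    return wait / total if total > 0 else 0.0
```

With the loop above, `batches` would be `iter(train_ds)` and `step_fn` a small wrapper around `train_step`.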

What I've tried so far:

  • setting all autotunable parameters to tf.data.AUTOTUNE: no effect;
  • iterating over the dataset object alone: all cores are used in this case, from which I conclude that parallelism is not globally disabled and the problem is at the graph execution level;
  • turning off the profiler: no effect;
  • lowering parallel_iterations in the map_fn call: no effect;
  • many different num_parallel_calls settings: no effect, to the point that the value doesn't seem to matter at all.
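The "iterate over the dataset alone" check from the list above can be made quantitative by measuring how many elements per second the pipeline produces without any training attached, then comparing that with the step rate during training. Below is a minimal pure-Python sketch; `elements_per_second` is a hypothetical helper, and it works on a tf.data.Dataset in eager mode because datasets are iterable. The `warmup` argument skips the first elements, which can be slow for reasons unrelated to steady-state throughput (a shuffle buffer, for instance, has to fill first).

```python
import itertools
import time
from typing import Iterable


def elements_per_second(dataset: Iterable, num_elements: int, warmup: int = 0) -> float:
    """Iterate a dataset (or any iterable) on its own and return its throughput."""
    it = iter(dataset)
    # Consume and discard the warm-up elements.
    for _ in itertools.islice(it, warmup):
        pass
    start = time.perf_counter()
    count = sum(1 for _ in itertools.islice(it, num_elements))
    elapsed = time.perf_counter() - start
    if count == 0:
        return 0.0
    return count / elapsed if elapsed > 0 else float("inf")
```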

Solution

  • I finally found the reason for this behaviour: it was caused by using XLA on the GPU.

    I stumbled upon this article and decided to turn XLA off, and oh god, after almost a week of investigation, the GPU was fully utilized and training times became way more sane (before that, they were equal to CPU training times!). As the article says: 1) GPU support in XLA is experimental; 2) tensors need to have inferrable shapes; 3) all operations in the graph must be supported by XLA. Signs of such problems are poor CPU and GPU utilization, as well as bouncing step times: for example, one step takes 150 seconds, the next 8-10 steps take one second each, and then the pattern repeats. The article is about TF 1.x, but not much seems to have changed on this topic since then (again, I'm using TF 2.6).
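The bouncing pattern described above (one very slow step followed by a run of fast ones, repeating) can be spotted automatically from logged per-step durations, such as the ones the stat_logger in the question could record. Below is a minimal pure-Python sketch; `slow_step_indices`, `step_gaps` and the 10x threshold are arbitrary, hypothetical choices. A regular spacing between slow steps is consistent with something being recompiled periodically, which is one possible explanation when shapes are not fully inferrable, but this check alone does not identify the cause.

```python
from statistics import median
from typing import List, Sequence


def slow_step_indices(step_seconds: Sequence[float], factor: float = 10.0) -> List[int]:
    """Indices of steps that took at least `factor` times the median step time."""
    if not step_seconds:
        return []
    typical = median(step_seconds)
    if typical <= 0:
        return []
    return [i for i, t in enumerate(step_seconds) if t >= factor * typical]


def step_gaps(indices: List[int]) -> List[int]:
    """Distances between consecutive slow steps; a constant gap suggests a periodic pattern."""
    return [b - a for a, b in zip(indices, indices[1:])]
```

For step times shaped like the ones above (150 s, then nine 1 s steps, repeated three times), the slow steps are at indices 0, 10 and 20, with a constant gap of 10.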

    Main takeaways:

    1. Don't use XLA with GPU blindly: if used incorrectly, it may degrade your GPU training times down to CPU level.
    2. If you do use XLA with GPU, make sure that you meet the requirements described above.

    I will update this answer if I manage to meet these XLA requirements in my computations and get a performance boost from XLA instead of a degradation.
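The answer doesn't say how XLA was enabled. Two common switches in TF 2.x are the global auto-clustering flag (`tf.config.optimizer.set_jit`, or the `TF_XLA_FLAGS=--tf_xla_auto_jit=2` environment variable) and explicit compilation with `tf.function(jit_compile=True)`. Below is a minimal sketch for A/B-testing a step with and without XLA. It assumes a toy matmul step in place of the real train_step; `make_step` and `mean_step_seconds` are hypothetical names. Keeping input shapes fixed (as drop_remainder=True does in the pipeline) helps avoid recompiling for every new shape.

```python
import time

import tensorflow as tf

# Turn off XLA auto-clustering globally; explicit jit_compile=True below still compiles.
tf.config.optimizer.set_jit(False)


def make_step(use_xla: bool):
    """Build a toy step function, compiled with XLA only when use_xla is True."""
    @tf.function(jit_compile=use_xla)
    def step(x, w):
        # Stand-in for a real train step: a matmul followed by a reduction.
        return tf.reduce_sum(tf.matmul(x, w))
    return step


def mean_step_seconds(step, x, w, repeats: int = 10) -> float:
    """Average wall-clock time per call, excluding the first (tracing/compiling) call."""
    step(x, w).numpy()  # warm-up: tracing, plus XLA compilation if enabled
    start = time.perf_counter()
    for _ in range(repeats):
        step(x, w).numpy()  # .numpy() waits for the result, so device time is included
    return (time.perf_counter() - start) / repeats
```

Calling `mean_step_seconds` on `make_step(False)` and `make_step(True)` with the same fixed-shape inputs gives a rough comparison; on the real model, the same `jit_compile` switch would be applied to `train_step`.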