Tensorflow Data API - prefetch

I am trying to use a new TensorFlow feature, the Data API, and I am not sure how prefetch works. In the code below

def dataset_input_fn(...):
    dataset = tf.data.TFRecordDataset(..., compression_type="ZLIB")
    dataset = dataset.map(lambda x: parser(...))
    dataset = dataset.map(lambda x, y: image_augmentation(...),
                          num_parallel_calls=num_threads)

    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

does it matter between which of the lines above I put dataset = dataset.prefetch(batch_size)? Or should it go after every operation that would have used output_buffer_size if the dataset were coming from tf.contrib.data?


  • In a discussion on GitHub I found a comment by mrry:

    Note that in TF 1.4 there will be a Dataset.prefetch() method that makes it easier to add prefetching at any point in the pipeline, not just after a map(). (You can try it by downloading the current nightly build.)


    For example, Dataset.prefetch() will start a background thread to populate an ordered buffer that acts like a tf.FIFOQueue, so that downstream pipeline stages need not block. However, the prefetch() implementation is much simpler, because it doesn't need to support as many different concurrent operations as a tf.FIFOQueue.

    So prefetch can be placed after any transformation, and it buffers the output of the transformation that precedes it. So far I have seen the biggest performance gains from putting it only at the very end of the pipeline.

    There is one more discussion, "Meaning of buffer_size in Dataset.map, Dataset.prefetch and Dataset.shuffle", where mrry explains a bit more about prefetch and its buffer.

    UPDATE 2018/10/01:

    Since version 1.7.0 the Dataset API (in contrib) has a prefetch_to_device option. Note that this transformation has to be the last one in the pipeline, and that contrib will be gone when TF 2.0 arrives. To make prefetch work on multiple GPUs, use MultiDeviceIterator (for an example, see #13610).
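
To illustrate the buffering that mrry describes, here is a minimal pure-Python sketch of the idea: a background thread fills a bounded, ordered buffer while the consumer reads from it. This is not how TensorFlow implements Dataset.prefetch(). The names prefetch, slow_batches and _END are made up for this illustration, and queue.Queue only stands in for the FIFO buffer.

```python
import queue
import threading
import time

_END = object()  # hypothetical sentinel marking the end of the upstream iterator


def prefetch(upstream, buffer_size):
    """Yield items from `upstream` in order, buffered by a background thread.

    Illustration only: a bounded queue.Queue plays the role of the ordered
    buffer. The producer thread starts when the first item is requested.
    """
    buffer = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in upstream:
                buffer.put(item)  # blocks while the buffer is full
        finally:
            # Always signal the end, even if upstream fails,
            # so the consumer does not wait forever.
            buffer.put(_END)

    # Daemon thread: it will not keep the process alive if the
    # consumer stops reading early.
    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buffer.get()  # blocks only while the buffer is empty
        if item is _END:
            return
        yield item


def slow_batches(n, delay=0.01):
    """Hypothetical stand-in for an expensive upstream stage (parse, augment, batch)."""
    for i in range(n):
        time.sleep(delay)
        yield i


result = list(prefetch(slow_batches(5), buffer_size=2))
print(result)  # [0, 1, 2, 3, 4]
```

The wrapper only buffers what the stage before it produces. Wrapping the last stage, as recommended in the answer above, therefore lets the whole upstream pipeline run while the consumer is busy with the previous item.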