TF Dataset API: Is the following sequence correct? map, cache, shuffle, batch, repeat, prefetch


I am using this sequence to read image files from disk and feed them into a TF Keras model.

    # Make the dataset for training
    dataset_train = tf.data.Dataset.from_tensor_slices((file_ids_training, file_names_training))
    dataset_train = dataset_train.flat_map(lambda file_id, file_name: tf.data.Dataset.from_tensor_slices(
        tuple(tf.py_func(_get_data_for_dataset, [file_id, file_name], [tf.float32, tf.float32]))))
    dataset_train = dataset_train.cache()

    # Shuffle, create batches, repeat, and prefetch
    dataset_train = dataset_train.shuffle(buffer_size=train_buffer_size)
    dataset_train = dataset_train.batch(train_batch_size)
    dataset_train = dataset_train.repeat()
    dataset_train = dataset_train.prefetch(1)
    dataset_train_iterator = dataset_train.make_one_shot_iterator()
    get_train_batch = dataset_train_iterator.get_next()

I am not sure whether this is the optimal sequence. For example, should repeat() come after shuffle() and before batch()? Should cache() come after batch()?


Solution

  • The answer to "Output differences when changing order of batch(), shuffle() and repeat()" suggests applying repeat or shuffle before batching. The order I often use is (1) shuffle, (2) repeat, (3) map, (4) batch, but it can vary with your preferences. I shuffle before repeat to avoid blurring epoch boundaries. I map before batch because my mapping function applies to a single example (not to a batch of examples), but you can certainly write a vectorized map function that expects a batch as input.
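
To make the epoch-boundary point concrete, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution, whereas the question uses the TF 1.x iterator API. The six-integer dataset, the `scale` function, and the helper names are hypothetical stand-ins for the real image-loading code, not part of the original question or answer.

```python
import tensorflow as tf

NUM_EXAMPLES = 6   # hypothetical tiny dataset standing in for the image files
BATCH_SIZE = 3
EPOCHS = 2


def scale(x):
    # Per-example map function (placeholder for real decoding or augmentation).
    return x * 10


def shuffle_then_repeat():
    # Order from the answer: shuffle, repeat, map, batch (plus prefetch).
    ds = tf.data.Dataset.range(NUM_EXAMPLES)
    ds = ds.shuffle(buffer_size=NUM_EXAMPLES)  # each repetition is reshuffled by default
    ds = ds.repeat(EPOCHS)
    ds = ds.map(scale)
    ds = ds.batch(BATCH_SIZE)
    return ds.prefetch(1)


def repeat_then_shuffle():
    # Repeating first lets the shuffle buffer mix elements from different epochs.
    ds = tf.data.Dataset.range(NUM_EXAMPLES)
    ds = ds.repeat(EPOCHS)
    ds = ds.shuffle(buffer_size=NUM_EXAMPLES * EPOCHS)
    ds = ds.map(scale)
    ds = ds.batch(BATCH_SIZE)
    return ds.prefetch(1)


def collect_epochs(ds):
    # Flatten all batches, then cut the stream into chunks of one epoch each.
    flat = [int(v) for batch in ds for v in batch.numpy()]
    return [flat[i:i + NUM_EXAMPLES] for i in range(0, len(flat), NUM_EXAMPLES)]


if __name__ == "__main__":
    print("shuffle -> repeat:", collect_epochs(shuffle_then_repeat()))
    print("repeat -> shuffle:", collect_epochs(repeat_then_shuffle()))
```

With shuffle before repeat, every chunk of `NUM_EXAMPLES` elements contains each example exactly once, so epochs stay intact. With repeat before shuffle, an example from the second pass can show up before the first pass has finished, so a chunk may contain duplicates and miss other examples. The exact order printed varies from run to run.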