python tensorflow tensorflow2.0 tensorflow-datasets

Efficient way to iterate over tf.data.Dataset


I want to know which is the most efficient way to iterate through a tf.data.Dataset in TensorFlow 2.4.

I am using the typical loop:

for example in dataset:
    ...  # per-example code goes here

However, I measured the wall time and, because my dataset is huge, the loop takes too long to run. Is there another option that reduces the computation time?


Solution

  • You can use the .map(map_func) method, which is an efficient way to apply preprocessing to each sample in your dataset. By default it processes samples one at a time; to run map_func on several samples in parallel, set the num_parallel_calls argument (for example, num_parallel_calls=tf.data.AUTOTUNE). [Reference]

    Here is an example from the TensorFlow website:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
    dataset = dataset.map(lambda x: x + 1) # instead of adding 1 to each sample in a for loop
    list(dataset.as_numpy_iterator())      # ==> [ 2, 3, 4, 5, 6 ]

    You can pass a function as well:

    def my_map(x):  # if the dataset also has y, use "def my_map(x, y)" and "return x, y"
        return x + 1

    dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
    dataset = dataset.map(my_map)          # instead of adding 1 to each sample in a for loop
    list(dataset.as_numpy_iterator())      # ==> [ 2, 3, 4, 5, 6 ]
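
The examples above call map without num_parallel_calls, so the samples are still processed one at a time. Below is a minimal sketch of the parallel variant. The function name add_one is hypothetical and only stands in for whatever the loop body does, and tf.data.AUTOTUNE assumes TensorFlow 2.4 or newer (earlier versions expose it as tf.data.experimental.AUTOTUNE).

```python
import tensorflow as tf


def add_one(x):
    # Hypothetical per-sample preprocessing; replace with the real loop body.
    return x + 1


dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]

# num_parallel_calls lets tf.data run add_one on several samples at once.
# tf.data.AUTOTUNE asks the runtime to pick the parallelism level.
dataset = dataset.map(add_one, num_parallel_calls=tf.data.AUTOTUNE)

# Convert to plain Python ints so the result is easy to inspect.
result = [int(v) for v in dataset.as_numpy_iterator()]
print(result)
```

With default options the output order matches the input order even when the map runs in parallel. Whether this is actually faster depends on how expensive map_func is; for a trivial function like this one, the gain is likely negligible.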