
Reading TFRecords with tf.data.Dataset API increases computation time


My data is in a TFRecords file. This simple code iterates over the images and batches them with the tf.data.Dataset API. Yet the computation time per 200 batches keeps increasing. Why is this, and how can I fix it?

import tensorflow as tf
import time

sess = tf.Session()
# Input pipeline: read the records, repeat forever, batch by 3
dataset = tf.data.TFRecordDataset('/tmp/data/train.tfrecords')
dataset = dataset.repeat()
dataset = dataset.batch(3)
iterator = dataset.make_one_shot_iterator()

prev_step = time.time()
for step in range(10000):
    # get_next() is called on every iteration of the loop
    tensors = iterator.get_next()
    fetches = sess.run(tensors)
    # Report the elapsed time every 200 steps
    if step % 200 == 0:
        print("Step %6i time since last %7.5f" % (step, time.time() - prev_step))
        prev_step = time.time()

This outputs the following times:

Step      0 time since last 0.01432
Step    200 time since last 1.85303
Step    400 time since last 2.15448
Step    600 time since last 2.65473
Step    800 time since last 3.15646
Step   1000 time since last 3.72434
Step   1200 time since last 4.34447
Step   1400 time since last 5.11210
Step   1600 time since last 5.87102
Step   1800 time since last 6.61459
Step   2000 time since last 7.57238
Step   2200 time since last 8.33060
Step   2400 time since last 9.37795

The TFRecords file contains MNIST images, written by following this how-to from the TensorFlow docs.

To narrow down the problem, I rewrote the code to read the raw images from disk instead. In that case, the time per 200 batches stays constant, as expected.

My questions are:

  • What part of the code increases the computation time?
  • Should I file this as a bug on the TensorFlow GitHub?
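One way to answer the first question is to count the operations in the graph while the loop runs. The sketch below assumes TF 1.13+ or TF 2.x (using the `tf.compat.v1` aliases), and it swaps the TFRecord file for an in-memory `tf.data.Dataset.range` pipeline so it runs without `/tmp/data/train.tfrecords`. The helper `count_ops_after_steps` is hypothetical, written only for illustration.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; also available in TF 2.x


def count_ops_after_steps(num_steps, hoist_get_next):
    """Hypothetical helper: return the graph's op count after each loop step."""
    with tf.Graph().as_default() as graph:
        # Assumption: an in-memory range dataset stands in for the TFRecordDataset.
        dataset = tf.data.Dataset.range(100).repeat().batch(3)
        iterator = tf1.data.make_one_shot_iterator(dataset)
        # When hoisted, get_next() is called once, before the loop.
        hoisted = iterator.get_next() if hoist_get_next else None
        counts = []
        with tf1.Session(graph=graph) as sess:
            for _ in range(num_steps):
                # Calling get_next() here creates a new op on every iteration.
                tensors = hoisted if hoist_get_next else iterator.get_next()
                sess.run(tensors)
                counts.append(len(graph.get_operations()))
        return counts


print(count_ops_after_steps(5, hoist_get_next=False))  # counts increase every step
print(count_ops_after_steps(5, hoist_get_next=True))   # counts stay the same
```

With get_next() inside the loop the graph keeps growing, which matches the steadily increasing times in the output above.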

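Solved!

Answer to my own question: move get_next() outside the loop. In graph mode, every call to iterator.get_next() adds a new op to the graph, so the graph grows with each step and each sess.run() call gets slower. Call get_next() once, before the loop, and run the same tensors on every iteration; the time per 200 batches then stops growing.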
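A minimal sketch of the fix, assuming TF 1.13+ or TF 2.x via `tf.compat.v1`. The helpers `write_dummy_tfrecords` and `time_batches` are hypothetical, and the dummy string records only stand in for the MNIST file. Calling `graph.finalize()` after building the pipeline is an optional guard: any later op creation, such as a stray `get_next()` inside the loop, raises a `RuntimeError` instead of silently slowing the loop down.

```python
import os
import tempfile
import time

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; also available in TF 2.x


def write_dummy_tfrecords(path, num_records=30):
    """Hypothetical helper: write placeholder records in place of the MNIST file."""
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(num_records):
            writer.write(("record-%d" % i).encode("utf-8"))


def time_batches(path, num_steps=1000, report_every=200):
    """Run the loop with get_next() hoisted; return (timings, ops_before, ops_after)."""
    with tf.Graph().as_default() as graph:
        dataset = tf.data.TFRecordDataset(path).repeat().batch(3)
        iterator = tf1.data.make_one_shot_iterator(dataset)
        next_batch = iterator.get_next()  # build the get-next op ONCE, outside the loop
        graph.finalize()  # from here on, adding any op raises RuntimeError
        ops_before = len(graph.get_operations())
        timings = []
        with tf1.Session(graph=graph) as sess:
            prev_step = time.time()
            for step in range(num_steps):
                sess.run(next_batch)  # re-run the same tensor; no new ops are created
                if step % report_every == 0:
                    elapsed = time.time() - prev_step
                    print("Step %6i time since last %7.5f" % (step, elapsed))
                    timings.append(elapsed)
                    prev_step = time.time()
        return timings, ops_before, len(graph.get_operations())


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "train.tfrecords")
    write_dummy_tfrecords(path)
    timings, ops_before, ops_after = time_batches(path)
    print("graph unchanged:", ops_before == ops_after)
```

In TF 2.x eager mode, iterating directly with `for batch in dataset:` sidesteps the problem, because no graph is being built.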

