
Is there a simple way to set epochs when using TFRecords with Tensorflow Estimators?


There is a nice way to set the number of epochs when feeding numpy arrays into an estimator:

  tf.estimator.inputs.numpy_input_fn(
      x,
      y=None,
      batch_size=128,
      num_epochs=1,
      shuffle=None,
      queue_capacity=1000,
      num_threads=1
  )
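
For context, a minimal usage sketch (the feature key "x", the toy arrays, and the estimator variable are illustrative assumptions, not part of the question):

    import numpy as np
    import tensorflow as tf

    # Hypothetical toy data; any numpy features/labels work the same way.
    features = {"x": np.random.rand(1000, 4).astype(np.float32)}
    labels = np.random.randint(0, 2, size=1000)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=128,
        num_epochs=5,  # the input_fn is exhausted after 5 passes over the data
        shuffle=True)

    # Training then stops automatically after num_epochs passes:
    # estimator.train(input_fn=train_input_fn)  # estimator assumed already built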

But I can't track down a similar method for TFRecords; most people seem to just stick it in a loop:

  i = 0
  while i < 100000:
      model.train(input_fn=input_fn, steps=100)
      i += 100  # advance by the number of steps just run

Is there a clean way to explicitly set the number of epochs for TFRecords with estimators?


Solution

  • You can set the number of epochs with dataset.repeat(num_epochs). The Dataset pipeline outputs a dataset object that yields (features, labels) tuples of the given batch size, which is what gets fed to model.train():

    dataset = tf.data.TFRecordDataset("file.tfrecords")
    dataset = dataset.shuffle(buffer_size=1000).repeat(num_epochs)
    ...  # parse the serialized examples here, e.g. with dataset.map()
    dataset = dataset.batch(batch_size)
    

    In order to make this work, call model.train(steps=None, max_steps=None). In that case you let the Dataset API handle the epoch count: once num_epochs is reached, it signals the end of input with tf.errors.OutOfRangeError (or a StopIteration exception), and training stops. A complete sketch follows below.
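
    Putting it together, here is a minimal sketch of such an input_fn wired into an Estimator. The file name train.tfrecords, the feature spec, and my_model_fn are illustrative placeholders, not part of the original answer:

      import tensorflow as tf

      NUM_EPOCHS = 10
      BATCH_SIZE = 32

      def train_input_fn():
          # Placeholder path; point this at your own TFRecord file(s).
          dataset = tf.data.TFRecordDataset("train.tfrecords")

          # Hypothetical feature spec; match however your records were written.
          feature_spec = {
              "x": tf.FixedLenFeature([4], tf.float32),
              "label": tf.FixedLenFeature([], tf.int64),
          }

          def parse(serialized):
              parsed = tf.parse_single_example(serialized, feature_spec)
              label = parsed.pop("label")
              return parsed, label

          dataset = dataset.map(parse)
          dataset = dataset.shuffle(buffer_size=1000).repeat(NUM_EPOCHS)
          dataset = dataset.batch(BATCH_SIZE)
          return dataset

      # With steps=None and max_steps=None (the defaults), training runs until
      # the dataset raises tf.errors.OutOfRangeError after NUM_EPOCHS passes.
      # estimator = tf.estimator.Estimator(model_fn=my_model_fn)  # my_model_fn assumed defined
      # estimator.train(input_fn=train_input_fn, steps=None, max_steps=None)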