There is a nice way to set the number of epochs when feeding numpy arrays into an estimator:
tf.estimator.inputs.numpy_input_fn(
    x,
    y=None,
    batch_size=128,
    num_epochs=1,
    shuffle=None,
    queue_capacity=1000,
    num_threads=1
)
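For example, a minimal TF 1.x usage sketch (the array shapes and values here are made up for illustration):

import numpy as np
import tensorflow as tf

# Toy data; replace with your real arrays.
x = {"x": np.random.rand(1000, 4).astype(np.float32)}
y = np.random.randint(0, 2, size=1000)

# Yields batches for exactly 5 passes over the data, then raises
# tf.errors.OutOfRangeError to end training.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=x, y=y, batch_size=128, num_epochs=5, shuffle=True)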
But I can't track down a similar method for TFRecords; most people seem to just stick the training call in a loop:
i = 0
while i < 100000:
    model.train(input_fn=input_fn, steps=100)
    i += 1
Is there a clean way to explicitly set the number of epochs for TFRecords with estimators?
You can set the number of epochs with dataset.repeat(num_epochs). The Dataset pipeline outputs a dataset object yielding (features, labels) tuples of batch size, which is what model.train() consumes:
dataset = tf.data.TFRecordDataset("file.tfrecords")
dataset = dataset.shuffle(buffer_size=1000).repeat(num_epochs)
...
dataset = dataset.batch(batch_size)
To make this work, call model.train(steps=None, max_steps=None).
In this case, you let the Dataset API handle the epoch count: it raises tf.errors.OutOfRangeError (or a StopIteration exception) once num_epochs is reached, which ends training.
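Putting it together, a minimal sketch of such an input_fn (the file name train.tfrecords, the feature spec, and NUM_EPOCHS are assumptions for illustration; model stands for the estimator from the question):

import tensorflow as tf

NUM_EPOCHS = 10  # assumed value

def parse_fn(serialized):
    # Hypothetical feature spec; replace with the features in your records.
    parsed = tf.parse_single_example(serialized, {
        "image": tf.FixedLenFeature([784], tf.float32),
        "label": tf.FixedLenFeature([], tf.int64),
    })
    return {"image": parsed["image"]}, parsed["label"]

def input_fn():
    dataset = tf.data.TFRecordDataset("train.tfrecords")
    dataset = dataset.map(parse_fn)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.repeat(NUM_EPOCHS)  # stop after NUM_EPOCHS passes
    dataset = dataset.batch(128)
    return dataset

# steps/max_steps are None, so training ends when the dataset raises
# tf.errors.OutOfRangeError after NUM_EPOCHS.
model.train(input_fn=input_fn, steps=None, max_steps=None)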