tf.estimator.Estimator gives different test accuracy when trained epoch by epoch vs over all epochs


I have defined a straightforward CNN as my model_fn for a tf.estimator.Estimator and feed it with this input_fn:

def input_fn(features, labels, batch_size, epochs):
    dataset = tf.data.Dataset.from_tensor_slices((features))
    dataset = dataset.map(lambda x: tf.cond(tf.random_uniform([], 0, 1) > 0.5, lambda: dataset_augment(x), lambda: x),
                          num_parallel_calls=16).cache()
    dataset_labels = tf.data.Dataset.from_tensor_slices((labels))
    dataset = dataset.zip((dataset, dataset_labels))
    dataset = dataset.shuffle(30000)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(-1)
    return dataset

When I train the estimator this way, I get 43% test accuracy after 10 epochs:

steps_per_epoch = data_train.shape[0] // batch_size
for epoch in range(1, epochs + 1):
    cifar100_classifier.train(lambda: input_fn(data_train, labels_train, batch_size, epochs=1), steps=steps_per_epoch)

But when I train it this way, I get 32% test accuracy after 10 epochs:

steps_per_epoch = data_train.shape[0] // batch_size
max_steps = epochs * steps_per_epoch
cifar100_classifier.train(steps=max_steps,
                              input_fn=lambda: input_fn(data_train, labels_train, batch_size, epochs=epochs))

I just cannot understand why these two methods produce different results. Can anyone please explain?


Solution

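  • In the first example you call input_fn once per epoch, so the whole pipeline is rebuilt each time and the augmentation coin-toss around dataset_augment(x) is redone for every x in every epoch. Over 10 epochs the model gets to see both the original and the augmented version of most samples.

    In the second example the coin-tosses happen only once: .cache() sits after the map, so the first pass through the data stores the result, and .repeat(epochs) replays that same cached data in every later epoch. Your training set is therefore effectively "smaller".

    The .cache() does not freeze the augmentation in the first example, because each call to train() builds a fresh dataset with an empty cache.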
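To make the difference concrete, here is a small pure-Python toy model of the two training loops. It does not use TensorFlow: augment, maybe_augment and variants_seen are hypothetical stand-ins, and the "cache" is just a list that is either refilled every epoch (like calling input_fn per epoch) or filled once and replayed (like a single pipeline with .cache() followed by .repeat(epochs)). It only counts how many distinct sample variants the model would get to see.

```python
import random


def augment(x):
    # Hypothetical stand-in for dataset_augment: tags the sample as augmented.
    return ("aug", x)


def maybe_augment(x, rng):
    # Coin-toss: augment with probability 0.5, otherwise pass the sample through.
    return augment(x) if rng.random() > 0.5 else x


def variants_seen(samples, epochs, retoss_each_epoch, seed=0):
    # Count the distinct (original or augmented) samples seen over all epochs.
    rng = random.Random(seed)
    seen = set()
    cached = None
    for _ in range(epochs):
        # Refill the "cache" every epoch, or only on the first pass.
        if retoss_each_epoch or cached is None:
            cached = [maybe_augment(x, rng) for x in samples]
        seen.update(cached)
    return len(seen)


samples = list(range(1000))
# Single pipeline: coin-tosses frozen by the cache after the first pass.
frozen = variants_seen(samples, epochs=10, retoss_each_epoch=False)
# input_fn called once per epoch: fresh coin-tosses every epoch.
fresh = variants_seen(samples, epochs=10, retoss_each_epoch=True)
print(frozen, fresh)
```

With the frozen cache every sample is seen in exactly one form, so the count equals the number of samples; with fresh coin-tosses it approaches twice that. If you want per-epoch augmentation with a single train() call, one option to try is moving .cache() before the random map so that only the raw features are cached; verify the effect on your own setup.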