Tags: tensorflow, tensorflow-estimator

`estimator.train` with num_steps in TensorFlow


I have made a custom estimator in TensorFlow 1.4. The `estimator.train` function has a `steps` parameter, which I am using to stop training periodically and then evaluate on my validation dataset:

while True:
    # Train for FLAGS.num_steps steps, then evaluate on the validation data
    model.train(input_fn=lambda: train_input_fn(train_data), steps=FLAGS.num_steps)
    model.evaluate(input_fn=lambda: train_input_fn(test_data))

After every `num_steps` training steps, I run evaluation on the validation dataset. What I observe is that after each evaluation, when training resumes, there is a jump in the plots of AUC and loss (and in general of all metrics).

Plot attached: [image not available]

I am unable to understand why it's happening.

Is this not the right way to evaluate metrics on the validation dataset at regular intervals?



Solution

  • The issue

    The issue comes from the fact that what you plot in TensorBoard is a streaming (running) accuracy or AUC, accumulated since the beginning of the current call to estimator.train(), not the value for each batch.

    Here is what happens in detail:

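    • you create a summary based on the second output of tf.metrics.accuracy, which is its update op:
    # tf.metrics.accuracy returns (accuracy_value, update_op); update_op increments the
    # metric's local variables and returns the running accuracy since their initialization
    accuracy = tf.metrics.accuracy(labels, predictions)
    tf.summary.scalar('accuracy', accuracy[1])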
    
    • when you call estimator.train(), a new Session is created and all local variables are re-initialized. This includes the local variables of the accuracy metric (its running total and count).

    • during this Session, the summaries collected by tf.summary.merge_all() are evaluated and written at regular intervals (every `save_summary_steps` steps of the tf.estimator.RunConfig, 100 by default). The value written is the accuracy over all batches processed since the current call to estimator.train() started. Therefore, at the beginning of each training phase the value is averaged over only a few batches and is quite noisy; it becomes more stable as more batches accumulate.

    • whenever you evaluate and then call estimator.train() again, the local variables are re-initialized and you enter another short "noisy" phase, which shows up as bumps in the training curve.
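The reset behaviour described above can be reproduced without TensorFlow. The following is a simplified pure-Python model, assuming that the streaming metric keeps only a running `total` of correct predictions and a `count` of examples (as the local variables of `tf.metrics.accuracy` do) and that each call to `train()` starts them from zero. `RunningAccuracy`, `train_phase` and `make_batches` are hypothetical names used only for illustration.

```python
import random


class RunningAccuracy:
    """Hypothetical stand-in for the local variables of a streaming accuracy metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Mirrors the re-initialization of local variables when train() opens a new Session.
        self.total = 0.0
        self.count = 0.0

    def update(self, labels, predictions):
        # Like the update op: accumulate this batch, then return the running accuracy.
        self.total += sum(1 for y, p in zip(labels, predictions) if y == p)
        self.count += len(labels)
        return self.total / self.count


def train_phase(metric, batches):
    """Simulates one call to train(): reset the metric, then log it after every batch."""
    metric.reset()
    return [metric.update(labels, preds) for labels, preds in batches]


def make_batches(rng, n_batches, batch_size, p_correct):
    """Random binary batches where each prediction is correct with probability p_correct."""
    batches = []
    for _ in range(n_batches):
        labels = [rng.randint(0, 1) for _ in range(batch_size)]
        preds = [y if rng.random() < p_correct else 1 - y for y in labels]
        batches.append((labels, preds))
    return batches


if __name__ == "__main__":
    rng = random.Random(0)
    metric = RunningAccuracy()
    for phase in range(3):
        curve = train_phase(metric, make_batches(rng, 50, 8, 0.8))
        # Early values average very few batches (noisy); later ones settle near 0.8.
        print(f"phase {phase}: first={curve[0]:.3f} last={curve[-1]:.3f}")
```

Because the simulated model has a fixed probability of being correct, any jump at a phase boundary comes purely from the reset of `total` and `count`, which is the effect producing the bumps in TensorBoard.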


    A solution

    If you want a scalar summary that gives you the actual accuracy of each batch, it seems you need to compute it without tf.metrics. For the accuracy, for instance, you can do:

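    # Accuracy of the current batch only: no local variables, so nothing is reset between train() calls
    accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))
    tf.summary.scalar('accuracy', accuracy)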
    

    This is easy to implement for the accuracy. I know it can be painful to do for AUC, but I don't see a better solution for now.
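For AUC, one possible approach (an assumption, not something prescribed above) is to compute a per-batch AUC directly from the scores with the rank-sum (Mann-Whitney U) formulation, and then expose it as a scalar summary, for example by wrapping the function with `tf.py_func` in TF 1.x. The sketch below shows only the per-batch computation in plain Python; `batch_auc` is a hypothetical helper name. It returns `None` for a batch that contains a single class, where AUC is undefined.

```python
def batch_auc(labels, scores):
    """Per-batch ROC AUC via the Mann-Whitney U statistic (ties count as 1/2).

    labels: sequence of 0/1 ints; scores: sequence of floats (higher = more positive).
    Returns None when the batch contains only positives or only negatives.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None  # AUC is undefined for a single-class batch
    wins = 0.0
    for p in pos:
        for n in neg:
            # Count positive/negative pairs ranked correctly; ties get half credit.
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    print(batch_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

The pairwise loop is quadratic in the batch size, which is usually acceptable for a single batch. Since nothing is accumulated across batches, the value is not affected by the re-initialization at each call to train(), although per-batch AUC will be noisier than a streaming AUC.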

    Maybe having these bumps is not so bad, though. For instance, if each call to estimator.train() covers one epoch, the last value logged is the overall training accuracy on that epoch.
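The epoch argument can be checked with a small worked example, assuming each call to train() covers exactly one epoch and the streaming metric is reset at its start. Because the metric accumulates correct predictions and example counts, its final value is the exact accuracy over the epoch, weighted by example; this can differ from the plain mean of per-batch accuracies when the last batch is smaller. The per-batch counts below are made up for illustration.

```python
# Hypothetical (correct, batch_size) pairs for one epoch; the last batch is partial.
batches = [(30, 32), (28, 32), (5, 10)]

total = 0
count = 0
for correct, size in batches:
    # What the streaming metric accumulates across the epoch.
    total += correct
    count += size
streaming_final = total / count

# Unweighted mean of the per-batch accuracies, for comparison.
mean_of_batches = sum(c / s for c, s in batches) / len(batches)

print(f"streaming value at end of epoch: {streaming_final:.4f}")
print(f"mean of per-batch accuracies:    {mean_of_batches:.4f}")
```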