tensorflow · backpropagation · gradient-descent · tensorflow-estimator

In the TensorFlow Estimator class, what does it mean to train for one step?


Specifically, within one step, how does it train the model? What is the stopping condition for gradient descent and backpropagation?

Docs here: https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train

e.g.

  import tensorflow as tf  # TF 1.x: tf.estimator.inputs.numpy_input_fn was removed in TF 2.x

  # cnn_model_fn, X_train, y_train, and logging_hook are defined elsewhere
  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn)

  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": X_train},
      y=y_train,
      batch_size=50,     # each step consumes one batch of 50 examples
      num_epochs=None,   # cycle through the data indefinitely
      shuffle=True)

  mnist_classifier.train(
      input_fn=train_input_fn,
      steps=100,         # run exactly 100 steps, i.e. 100 batches
      hooks=[logging_hook])

I understand that training for one step means feeding the neural network model batch_size data points once. My question is: within this one step, how many times does it perform gradient descent? Does it do backpropagation and gradient descent just once, or does it keep performing gradient descent until the model weights reach an optimum for this batch of data?


Solution

  • 1 step = 1 gradient update. Each gradient update requires exactly one forward pass and one backward pass.
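    For intuition, one step corresponds to roughly the following (a TF 2.x GradientTape sketch of the general pattern, not the Estimator's actual internals; the model, optimizer, and loss below are placeholder assumptions, not your cnn_model_fn):

      import tensorflow as tf  # TF 2.x

      # Placeholder model/optimizer/loss -- assumptions for illustration
      model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

      def train_one_step(x_batch, y_batch):
          with tf.GradientTape() as tape:
              logits = model(x_batch)          # one forward pass
              loss = loss_fn(y_batch, logits)
          # one backward pass
          grads = tape.gradient(loss, model.trainable_variables)
          # one parameter update -- and the step is done
          optimizer.apply_gradients(zip(grads, model.trainable_variables))
          return loss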

    The stopping condition is generally left up to you and is arguably more art than science. Commonly you will plot (TensorBoard is handy here) your cost, training accuracy, and periodically your validation-set accuracy. The point where validation accuracy peaks (equivalently, where validation loss bottoms out) is generally a good point to stop. Depending on your dataset, validation accuracy may dip and at some point recover, or it may simply flatten out, at which point the stopping condition often correlates with the developer's degree of impatience.

    Here's a nice article on stopping conditions; a Google search will turn up plenty more.

    https://stats.stackexchange.com/questions/231061/how-to-use-early-stopping-properly-for-training-deep-neural-network
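    Newer TensorFlow releases also ship a ready-made early-stopping hook for Estimators. A hedged sketch (assuming TF >= 1.13, where the hook lives under tf.estimator.experimental, that your model_fn reports an "accuracy" eval metric, and that an eval_input_fn exists):

      early_stop = tf.estimator.experimental.stop_if_no_increase_hook(
          mnist_classifier,
          metric_name="accuracy",           # must match a metric your model_fn reports
          max_steps_without_increase=5000,  # patience -- an arbitrary choice
          min_steps=1000)                   # ignore the noisy early phase

      tf.estimator.train_and_evaluate(
          mnist_classifier,
          tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=[early_stop]),
          tf.estimator.EvalSpec(input_fn=eval_input_fn))  # eval_input_fn assumed defined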

    Another common approach to stopping is to drop the learning rate every time you observe that validation accuracy has not improved for some "reasonable" number of steps. When the learning rate has effectively hit 0, you call it quits. A sketch of that loop follows below.
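    Here is a framework-agnostic sketch of that "decay on plateau, quit near 0" loop; train_some_steps and evaluate are hypothetical helpers, not TF APIs, and the thresholds are arbitrary assumptions:

      lr = 0.1
      best_acc = 0.0
      stale_steps = 0
      PATIENCE = 2000       # "reasonable" number of steps -- an assumption

      while lr > 1e-6:      # "effectively 0" -- threshold is an assumption
          train_some_steps(learning_rate=lr, steps=500)   # hypothetical helper
          acc = evaluate()                                # hypothetical helper
          if acc > best_acc:
              best_acc, stale_steps = acc, 0
          else:
              stale_steps += 500
              if stale_steps >= PATIENCE:
                  lr *= 0.1         # drop the learning rate
                  stale_steps = 0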