
How can I get the global step in tf.estimator?


Does anyone know how I can get the global step count inside the Estimator's model function?
I need it to adjust the learning rate when the optimizer is created,
as in the example below:

def estimator_fn(features, labels, mode):
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = xxx(learning_rate=GLOBAL_STEP * some_process)

And does tf.train.get_global_step() work there?


Solution

  • You can get the global step with tf.train.get_or_create_global_step():

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        learning_rate = learning_rate_fn(global_step)
    

    where learning_rate_fn is a function you write that computes the learning rate from the global step.

    For more information, look here.
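The answer leaves learning_rate_fn unspecified. As a minimal sketch, here is a hypothetical learning_rate_fn written in plain Python so the arithmetic is easy to check. It applies the exponential-decay formula documented for tf.train.exponential_decay (lr * decay_rate ^ (global_step / decay_steps), with the exponent floored when staircase is on) and adds an optional linear warmup, which matches the "GLOBAL_STEP * something" idea from the question. The function names, default values, and the warmup itself are assumptions for illustration, not part of the original answer.

```python
import math


def exponential_decay(global_step, initial_lr=0.1, decay_steps=1000,
                      decay_rate=0.96, staircase=False):
    """Plain-Python version of the exponential-decay schedule.

    lr = initial_lr * decay_rate ** (global_step / decay_steps)
    With staircase=True the exponent is floored, so the rate drops in steps.
    """
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)
    return initial_lr * decay_rate ** exponent


def learning_rate_fn(global_step, warmup_steps=500, **decay_kwargs):
    """Hypothetical schedule: linear warmup, then exponential decay."""
    decayed = exponential_decay(global_step, **decay_kwargs)
    if global_step < warmup_steps:
        # During warmup the rate grows linearly with the step count.
        return decayed * (global_step + 1) / warmup_steps
    return decayed


# Print the schedule at a few steps to see warmup and staircase decay.
for step in (0, 499, 1000, 2000):
    print(step, round(learning_rate_fn(step, staircase=True), 6))
```

Inside a TF1-style model_fn you would pass the global_step tensor to a TensorFlow op rather than a Python int, for example learning_rate = tf.train.exponential_decay(0.1, global_step, 1000, 0.96, staircase=True), and then call optimizer.minimize(loss, global_step=global_step) so the step actually increments on each training iteration. The Estimator normally creates the global step before calling model_fn, so tf.train.get_global_step() should also return it there; get_or_create_global_step() simply avoids getting None if it does not exist yet.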