Does anyone know how I can get the global step count inside an estimator definition?
I need it to adjust the learning rate when the optimizer is created,
as in the example below:
def estimator_fn(features, labels, mode):
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = xxx(learning_rate=GLOBAL_STEP * some process)
Also, does tf.train.get_global_step work for this?
You can also use this:
if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_fn(global_step)
where learning_rate_fn is a function you define to compute the learning rate from the global step.
For more info, look here.
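To make the learning_rate_fn idea concrete, here is a minimal sketch of one possible schedule, an exponential decay. The parameter names and default values (base_lr, decay_rate, decay_steps) are made up for illustration and are not from the original post. It uses only plain arithmetic, so it can be checked with ordinary Python numbers; I'm assuming the same operators also work when the step is a float tensor.

```python
def learning_rate_fn(global_step, base_lr=0.1, decay_rate=0.96, decay_steps=1000):
    """Hypothetical exponential-decay schedule (names and defaults are assumptions).

    Multiplies base_lr by decay_rate once per decay_steps steps, with
    smooth interpolation between those points.
    """
    # Fraction of a decay period that has elapsed so far.
    periods = global_step / decay_steps
    return base_lr * decay_rate ** periods


# Quick check with plain Python numbers.
print(learning_rate_fn(0))
print(learning_rate_fn(1000))
```

Inside the model function you would probably cast the step first, since the global step is an integer tensor, e.g. learning_rate_fn(tf.cast(global_step, tf.float32)), and then pass the result as learning_rate to your optimizer.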