Tags: tensorflow, tensorflow-estimator

Can I get the tensorflow session from the estimator?


I'm using the LinearRegressor from tf.estimator and want to change my learning rate decay (originally exponential decay) to one driven by the loss. To do this, I need to feed the evaluation loss into a placeholder used by the learning-rate decay tensor, and for that step I need a tf.Session.

I tried tf.get_default_session() to get the session created by the estimator, but that session uses a different graph from the one the estimator uses.


    def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
      # If the loss has not decreased, decay the learning rate by decay_rate.
      ...  # implementation omitted

    loss = tf.placeholder(tf.float32)
    estimator = tf.estimator.LinearRegressor(
        feature_columns=feature_columns,
        optimizer=lambda: tf.train.FtrlOptimizer(
            learning_rate=my_decay(learning_rate=0.1,
                                   global_step=tf.train.get_global_step(),
                                   decay_step=10000,
                                   loss=loss, decay_rate=0.96)),
        config=sess_config)

    for _ in range(n_epoches):
      metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
      # Pseudocode for what I want: push metrics['loss'] into `loss` using the
      # estimator's session, which is exactly what I can't get.
      session.run(loss.assign(metrics['loss']))

With the code above, I need the session that the estimator creates. Is there any way to get it?

Thank you in advance!


Solution

  • The intended solution for something like this is to subclass tf.train.SessionRunHook and override its before_run method to return a suitable tf.train.SessionRunArgs. This lets you feed values and add fetches to each session.run call the estimator makes during training. Your class will have to carry a reference to the placeholder and keep the loss state between calls.

    Then you simply instantiate the class and add the hook to the hooks parameter of your estimator.train call or, in this case, of your train_spec. If you want to use the evaluation loss instead of the training loss, add another hook to the eval_spec that requests the loss tensor in before_run and reads its value in the after_run method.
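A minimal sketch of the two hooks described above, assuming TF 1.13+ or TF 2.x with the `tf.compat.v1` graph-mode API (on older 1.x releases the same classes live directly under `tf.train`). `FeedLossHook`, `CollectLossHook`, the placeholder name `eval_loss:0`, and the `demo_*` helpers are hypothetical names chosen for illustration. The name of the loss tensor in a canned estimator's evaluation graph is not given in the question, so `loss_tensor_name` is an assumption you would fill in for your own graph. The demos exercise the hooks in a small `MonitoredSession` instead of a real estimator.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x graph-mode API, also available in TF 2.x


class FeedLossHook(tf1.train.SessionRunHook):
    """Feeds the latest evaluation loss into a placeholder on every train step."""

    def __init__(self, placeholder_name="eval_loss:0", initial_loss=float("inf")):
        self._placeholder_name = placeholder_name  # hypothetical tensor name
        self._loss = initial_loss  # loss state carried between session.run calls
        self._placeholder = None

    @property
    def loss(self):
        return self._loss

    def set_loss(self, value):
        self._loss = float(value)

    def begin(self):
        # The estimator builds a fresh graph for each train() call, so look up
        # the placeholder in whichever graph is the default right now.
        self._placeholder = tf1.get_default_graph().get_tensor_by_name(
            self._placeholder_name)

    def before_run(self, run_context):
        # No extra fetches; just add our value to this step's feed_dict.
        return tf1.train.SessionRunArgs(
            fetches=None, feed_dict={self._placeholder: self._loss})


class CollectLossHook(tf1.train.SessionRunHook):
    """Averages a loss tensor over evaluation steps and hands it to a FeedLossHook."""

    def __init__(self, loss_tensor_name, feed_hook):
        self._loss_tensor_name = loss_tensor_name  # depends on your model_fn
        self._feed_hook = feed_hook

    def begin(self):
        self._loss = tf1.get_default_graph().get_tensor_by_name(
            self._loss_tensor_name)
        self._total, self._steps = 0.0, 0

    def before_run(self, run_context):
        # Ask for the loss to be fetched alongside the estimator's own fetches.
        return tf1.train.SessionRunArgs(fetches=self._loss)

    def after_run(self, run_context, run_values):
        # run_values.results holds exactly what before_run requested.
        self._total += float(run_values.results)
        self._steps += 1

    def end(self, session):
        # Called once when evaluation finishes; pass the mean loss along.
        if self._steps:
            self._feed_hook.set_loss(self._total / self._steps)


def demo_feed(value):
    """Runs one step in a throwaway graph; returns the fed placeholder times 2."""
    hook = FeedLossHook()
    hook.set_loss(value)
    with tf1.Graph().as_default():
        loss_ph = tf1.placeholder(tf.float32, shape=[], name="eval_loss")
        doubled = loss_ph * 2.0
        with tf1.train.MonitoredSession(hooks=[hook]) as sess:
            return float(sess.run(doubled))


def demo_collect(step_losses):
    """Simulates an evaluation loop; returns the loss the train hook received."""
    feed_hook = FeedLossHook()
    with tf1.Graph().as_default():
        fed = tf1.placeholder(tf.float32, shape=[])
        step_loss = tf1.identity(fed, name="step_loss")
        collect = CollectLossHook("step_loss:0", feed_hook)
        with tf1.train.MonitoredSession(hooks=[collect]) as sess:
            for v in step_losses:
                sess.run(step_loss, feed_dict={fed: v})
    return feed_hook.loss


print(demo_feed(1.5))            # 3.0
print(demo_collect([1.0, 3.0]))  # 2.0
```

In the question's setup, the placeholder would presumably be created inside the optimizer callable (for example with `tf.placeholder_with_default(float('inf'), shape=[], name='eval_loss')`) so that it exists in the estimator's graph, and the hooks would go into `tf.estimator.TrainSpec(..., hooks=[feed_hook])` and `tf.estimator.EvalSpec(..., hooks=[collect_hook])`. Using `placeholder_with_default` is a choice made here so that graphs without the hook attached still run. If the placeholder ends up inside a name scope, `placeholder_name` needs the full scoped name.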