
"Cannot evaluate tensor using `eval()`: No default session is registered" while using custom SessionRunHook with Estimator API

I'm following this example to learn how to build a CNN in TensorFlow with the Estimator API. The example contains the line `pred_probas = tf.nn.softmax(logits_test)`. Those probabilities would be very valuable to me, because I'd like to use them in this small snippet I wrote:

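```python
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_curve

def eer_eval(y_true, probas):
    fpr, tpr, thresholds = roc_curve(y_true.eval(), probas[:, 1].eval())
    return brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
```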
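For reference, the `brentq` call above solves for the equal error rate (EER). Treating `interp1d(fpr, tpr)` as the true positive rate expressed as a function of the false positive rate, and using FNR = 1 - TPR, the root it finds is the point where the two error rates coincide:

$$
1 - x^\ast - \mathrm{TPR}(x^\ast) = 0
\quad\Longleftrightarrow\quad
\mathrm{FPR} = x^\ast = 1 - \mathrm{TPR}(x^\ast) = \mathrm{FNR},
\qquad \mathrm{EER} = x^\ast
$$

The bracket `[0, 1]` works because the ROC curve ends at (1, 1), so f(1) = 1 - 1 - 1 = -1, while f(0) = 1 - TPR(0) >= 0.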

After reading this post, I wrote my own hook:

```python
class _EERHook(tf.train.SessionRunHook):
    def __init__(self, probas, labels):
        self.labels = labels
        self.probas = probas

    def begin(self):
        pass

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(eer_eval(self.labels, self.probas))

    def after_run(self,
                  run_context,  # pylint: disable=unused-argument
                  run_values):
        eer = run_values.results
        print("EER: ", eer)
```

I'd like to use it during evaluation of the model:

```python
estim_specs = tf.estimator.EstimatorSpec(
    mode=mode,
    # ... loss, train_op, etc. as in the example ...
    eval_metric_ops={'accuracy': acc_op},
    evaluation_hooks=[_EERHook(pred_probas, labels)])
```

However, the code crashes with the following error:

```
ValueError: Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`
```

Is there any way to save those probabilities to a human-readable CSV file during evaluation, or to make my snippet work?
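On the CSV part: once a hook fetches the label and probability tensors through `tf.train.SessionRunArgs`, `after_run` receives them in `run_values.results` as NumPy arrays, and the standard `csv` module can write them out. The sketch below is a hypothetical helper (`append_probas_csv` and the column names are not part of any TensorFlow API). It assumes one integer label per example and one probability per class. The file is opened in append mode because `after_run` is called once per evaluation batch.

```python
import csv


def append_probas_csv(path, labels, probas, write_header=False):
    """Hypothetical helper: append one row per example (label, then class probabilities).

    labels: 1-D sequence of integer class labels (e.g. a NumPy array)
    probas: 2-D sequence, one row of class probabilities per example
    """
    # Append mode, so rows from successive evaluation batches accumulate in one file
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            n_classes = len(probas[0]) if len(probas) else 0
            writer.writerow(["label"] + [f"p_class_{k}" for k in range(n_classes)])
        for y, row in zip(labels, probas):
            # int()/float() turn NumPy scalars into plain, readable numbers
            writer.writerow([int(y)] + [float(p) for p in row])
```

Inside `after_run`, this could be called as `append_probas_csv('eval_probas.csv', results[0], results[1])`, where the file name is only an example. Pass `write_header=True` on the first call only, for instance by tracking a flag set in `begin()`.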


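  • The function eer_eval(y_true, probas) is not written in TensorFlow style: it calls .eval() on graph tensors, which requires a default session, and none is registered while the hook runs (hence the error). It may be better to let the hook fetch y_true and probas, then pass the resulting NumPy values to eer_eval() after removing its .eval() calls.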

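    In _EERHook:

    ```python
    def before_run(self, run_context):
        # Ask the session to fetch these tensors along with the current step
        return tf.train.SessionRunArgs((self.labels, self.probas))

    def after_run(self,
                  run_context,  # pylint: disable=unused-argument
                  run_values):
        # run_values.results holds NumPy arrays matching the fetches above
        results = run_values.results
        print('labels:', results[0])
        print('probas:', results[1])
        # eer_eval(results[0], results[1])  # eer_eval must not call .eval() here
    ```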
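One caveat with calling eer_eval() inside after_run: the hook fires once per evaluation batch, so it would report a per-batch EER. A sketch of an alternative is to collect the fetched arrays across batches and compute one EER at the end, for example from the hook's `end(self, session)` method. The names `EERAccumulator` and `equal_error_rate` are hypothetical. The sketch assumes binary 0/1 labels and that column 1 of `probas` is the positive class, as in the question's `probas[:, 1]`. It also swaps the `roc_curve`/`brentq` interpolation for a plain threshold sweep, so its value can differ slightly from eer_eval's interpolated result.

```python
def equal_error_rate(labels, scores):
    """Approximate the EER by trying every distinct score as a threshold.

    Swapped-in technique: a threshold sweep instead of ROC interpolation.
    labels: iterable of 0/1 ground-truth values
    scores: iterable of positive-class probabilities
    """
    pairs = list(zip(scores, labels))
    n_pos = sum(1 for _, y in pairs if y == 1)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative examples")
    best_gap, best_eer = float("inf"), None
    for t in sorted({s for s, _ in pairs}):
        # Predict "positive" when score >= t
        fp = sum(1 for s, y in pairs if s >= t and y == 0)
        fn = sum(1 for s, y in pairs if s < t and y == 1)
        fpr, fnr = fp / n_neg, fn / n_pos
        gap = abs(fpr - fnr)
        if gap < best_gap:
            # Keep the threshold where FPR and FNR are closest; report their mean
            best_gap, best_eer = gap, (fpr + fnr) / 2
    return best_eer


class EERAccumulator:
    """Hypothetical helper: gathers per-batch results fetched by a hook."""

    def __init__(self):
        self.labels, self.scores = [], []

    def add_batch(self, labels, probas):
        # Works for NumPy arrays or nested lists; column 1 = positive class
        self.labels.extend(int(y) for y in labels)
        self.scores.extend(float(row[1]) for row in probas)

    def result(self):
        return equal_error_rate(self.labels, self.scores)
```

In the hook, this would mean creating an `EERAccumulator` in `begin()`, calling `add_batch(results[0], results[1])` in `after_run`, and printing `result()` in `end(self, session)`.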