Log accuracy metric while training a tf.estimator

What's the simplest way to print accuracy metrics along with the loss when training a pre-canned estimator?

Most tutorials and documentation seem to address the case where you're creating a custom estimator, which seems like overkill if the intention is to use one of the available ones.

tf.contrib.learn had a few (now deprecated) Monitor hooks. TF now suggests using the hook API, but it doesn't appear to come with anything that can use the labels and predictions to generate an accuracy number.


  • Have you tried tf.contrib.estimator.add_metrics(estimator, metric_fn)? It takes an initialized estimator (which can be pre-canned) and adds to it the metrics defined by metric_fn.

    Usage Example:

    import tensorflow as tf  # TF 1.x API (tf.contrib was removed in TF 2.x)

    def custom_metric(labels, predictions):
        # This function is called by the Estimator, which passes in its predictions.
        # Suppose you want to add a "mean" metric.
        # Access the class predictions (careful, the key name may differ from one canned Estimator to another):
        predicted_classes = predictions["class_ids"]
        # Define the metric (value and update tensors).
        # Note: the second positional argument of tf.metrics.mean is `weights`,
        # so this is the mean of the labels weighted by the predicted class ids.
        metric = tf.metrics.mean(values=labels, weights=predicted_classes, name="custom_metric")
        # Return it as a dict:
        return {"custom_metric": metric}

    # Initialize your canned Estimator:
    classifier = tf.estimator.DNNClassifier(feature_columns=columns_feat, hidden_units=[10, 10], n_classes=NUM_CLASSES)
    # Add your custom metrics:
    classifier = tf.contrib.estimator.add_metrics(classifier, custom_metric)
    # Train/evaluate:
    tf.logging.set_verbosity(tf.logging.INFO)  # Just to have some logs to display for demonstration
    train_spec = tf.estimator.TrainSpec(input_fn=lambda: your_train_dataset_function())
    # your_eval_dataset_function is a placeholder for your own evaluation input function.
    eval_spec = tf.estimator.EvalSpec(input_fn=lambda: your_eval_dataset_function())
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)


    Logs:

    INFO:tensorflow:Running local_init_op.
    INFO:tensorflow:Done running local_init_op.
    INFO:tensorflow:Evaluation [20/200]
    INFO:tensorflow:Evaluation [40/200]
    INFO:tensorflow:Evaluation [200/200]
    INFO:tensorflow:Finished evaluation at 2018-04-19-09:23:03
    INFO:tensorflow:Saving dict for global step 1: accuracy = 0.5668, average_loss = 0.951766, custom_metric = 1.2442, global_step = 1, loss = 95.1766

    As you can see, custom_metric is reported alongside the default metrics and the loss in the evaluation output.
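For intuition about the "value and update tensors" mentioned in the code comment, here is a minimal plain-Python sketch of how a streaming mean accumulates across evaluation batches. It does not use TensorFlow; the StreamingMean class and the sample batches are hypothetical and only illustrate the idea that tf.metrics.mean keeps a running total and count, updates them once per batch, and reads the mean from them.

```python
class StreamingMean:
    """Plain-Python stand-in for a streaming (optionally weighted) mean metric."""

    def __init__(self):
        self.total = 0.0  # running sum of value * weight
        self.count = 0.0  # running sum of weights

    def update(self, values, weights=None):
        """Fold one batch into the running state (the "update" step)."""
        if weights is None:
            weights = [1.0] * len(values)
        self.total += sum(v * w for v, w in zip(values, weights))
        self.count += sum(weights)

    def value(self):
        """Read the current mean (the "value" step); 0.0 before any update."""
        return self.total / self.count if self.count else 0.0


# Hypothetical batches of predicted class ids, fed one batch at a time.
metric = StreamingMean()
for batch in ([0, 1, 2], [2, 2], [1]):
    metric.update(batch)
print(metric.value())
```

This is why the evaluation log shows a single custom_metric number even though evaluation ran over many batches: each batch only updates the accumulated state, and the reported value is read once at the end.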