python, tensorflow, tensorboard

Tensorflow plot tf.metrics.precision_at_thresholds in Tensorboard through eval_metric_ops


tf.metrics.precision_at_thresholds() takes three arguments: labels, predictions, and thresholds, where thresholds is a Python list or tuple of thresholds in [0, 1]. The function then returns "A float Tensor of shape [len(thresholds)]", which is problematic for automatically plotting eval_metric_ops to TensorBoard (as I believe they are expected to be scalars). The values print to the console just fine, but I would also like to plot them in TensorBoard. Is there any adjustment that can be made so these values can be plotted in TensorBoard?
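
To illustrate the shape issue, here is a minimal sketch (the placeholder inputs are only for demonstration):

    import tensorflow as tf

    labels = tf.placeholder(tf.float32, shape=[None])
    predictions = tf.placeholder(tf.float32, shape=[None])

    # The value tensor has shape [len(thresholds)] == [2] here, not a scalar.
    precision, precision_op = tf.metrics.precision_at_thresholds(
        labels=labels, predictions=predictions, thresholds=[0.5, 0.7])

    print(precision.shape)  # (2,)
    # tf.summary.scalar('precision', precision)  # fails: a scalar tensor is expected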


Solution

  • I found it really strange that TensorFlow (as of 1.8) does not offer a summary function for metrics like tf.metrics.precision_at_thresholds (in general tf.metrics.*_at_thresholds). The following is a minimal working example:

    def summarize_metrics(metric_ops):
        for metric_op in metric_ops:
            shape = metric_op.shape.as_list()
            if shape:  # this is a metric created with any of tf.metrics.*_at_thresholds
                summary_components = tf.split(metric_op, shape[0])
                for i, summary_component in enumerate(summary_components):
                    tf.summary.scalar(
                        # use the op name (without the ':0' output suffix) so it is a valid summary name
                        name='{op_name}_{i}'.format(op_name=metric_op.op.name, i=i),
                        tensor=tf.squeeze(summary_component, axis=[0])
                    )
            else:  # this already is a scalar metric operator
                tf.summary.scalar(name=metric_op.op.name, tensor=metric_op)
    
    precision, precision_op = tf.metrics.precision_at_thresholds(
        labels=labels, predictions=predictions, thresholds=thresholds)
    summarize_metrics([precision_op])
    

    The downside of this approach, in general, is that the notion of which thresholds you used to create the metric in the first place is lost when summarizing it. I came up with a slightly more complex, but easier-to-use, solution that uses collections to store all metric update operators.

    # Create a metric and let it add the vars and update operators to the specified collections
    thresholds = [0.5, 0.7]
    tf.metrics.recall_at_thresholds(
        labels=labels, predictions=predictions, thresholds=thresholds,
        metrics_collections='metrics_vars', updates_collections='metrics_update_ops'
    )
    
    # Anywhere else call the summary method I provide in the Gist at the bottom [1]
    # Because we provide a mapping of a scope pattern to the thresholds, we can
    # assign them later
    summarize_metrics(list_lookup={'recall_at_thresholds': thresholds})
    

    The implementation in the Gist [1] below also supports options for formatting the sometimes cryptic names of the metrics nicely.
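
    For reference, here is a rough sketch of what such a collection-driven helper could look like. This is a simplified illustration built around the names used above (list_lookup, the 'metrics_update_ops' collection), not the actual Gist code:

    import tensorflow as tf

    def summarize_metrics(list_lookup=None, collection='metrics_update_ops'):
        list_lookup = list_lookup or {}
        for metric_op in tf.get_collection(collection):
            shape = metric_op.shape.as_list()
            if not shape:  # scalar metric: summarize it directly
                tf.summary.scalar(name=metric_op.op.name, tensor=metric_op)
                continue
            # Vector metric (one value per threshold): recover the thresholds by
            # matching the patterns from list_lookup against the op name.
            thresholds = next((t for pattern, t in list_lookup.items()
                               if pattern in metric_op.op.name), None)
            components = tf.split(metric_op, shape[0])
            for i, component in enumerate(components):
                suffix = str(thresholds[i]) if thresholds else str(i)
                tf.summary.scalar(
                    name='{}_{}'.format(metric_op.op.name, suffix),
                    tensor=tf.squeeze(component, axis=[0]))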

    [1]: https://gist.github.com/patzm/961dcdcafbf3c253a056807c56604628

    What this could look like in TensorBoard: [screenshot on Imgur]