TensorFlow metric: top-N accuracy


I'm using add_metric to try to create a custom metric that computes top-3 accuracy for a classifier. Here's as far as I got:

def custom_metrics(labels, predictions):
    # labels => Tensor("fifo_queue_DequeueUpTo:422", shape=(?,), dtype=int64)
    # predictions => {
    #     'logits': <tf.Tensor 'dnn/logits/BiasAdd:0' shape=(?, 26) dtype=float32>,
    #     'probabilities': <tf.Tensor 'dnn/head/predictions/probabilities:0' shape=(?, 26) dtype=float32>,
    #     'class_ids': <tf.Tensor 'dnn/head/predictions/ExpandDims:0' shape=(?, 1) dtype=int64>,
    #     'classes': <tf.Tensor 'dnn/head/predictions/str_classes:0' shape=(?, 1) dtype=string>
    # }
    ...  # top-3 accuracy computation goes here

Looking at the implementation of the existing tf.metrics, everything is built from TF ops. How can I implement top-3 accuracy the same way?


Solution

  • If you want to implement it yourself, tf.nn.in_top_k is very useful: it returns a boolean tensor indicating, for each example, whether the target class is within the top k predictions. You just have to take the mean of the result:

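    import tensorflow as tf

    def custom_metrics(labels, predictions):
        # Score the (batch, classes) logits, not the whole dict; add_metrics wants {name: (value_op, update_op)}
        hits = tf.nn.in_top_k(predictions=predictions['logits'], targets=labels, k=3)
        return {'top_3_accuracy': tf.metrics.mean(tf.cast(hits, tf.float32))}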

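    Alternatively, Keras ships a ready-made function you can import:

    from tensorflow.keras.metrics import top_k_categorical_accuracy

    Note that it expects one-hot labels and defaults to k=5; for integer class ids like the labels above, sparse_top_k_categorical_accuracy(y_true, y_pred, k=3) is the matching variant.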
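For intuition, and to sanity-check the numbers the TensorFlow metric reports, here is a minimal pure-Python sketch of the same computation. It assumes finite scores and valid class indices, and the helper names in_top_k_ref and StreamingMean are hypothetical, not TensorFlow APIs. The hit test mirrors the tie rule documented for tf.nn.in_top_k (classes tied at the k-th boundary all count as in the top k), and the running total/count imitates how a streaming mean accumulates across evaluation batches.

```python
# Pure-Python reference sketch of top-k accuracy as a streaming metric.
# in_top_k_ref and StreamingMean are hypothetical helpers, not TensorFlow APIs.


def in_top_k_ref(predictions, targets, k):
    """Per-example hit flags. A target is a hit when fewer than k classes
    score strictly higher, so classes tied at the k-th boundary all count
    as in the top k (the tie rule documented for tf.nn.in_top_k).
    Assumes finite scores and valid target indices."""
    hits = []
    for scores, target in zip(predictions, targets):
        target_score = scores[target]
        higher = sum(1 for s in scores if s > target_score)
        hits.append(higher < k)
    return hits


class StreamingMean:
    """Running mean over batches using total/count accumulators:
    update() folds in a batch, result() reads the current mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        for v in values:
            self.total += float(v)  # True -> 1.0, False -> 0.0
            self.count += 1

    def result(self):
        # Returning 0.0 for an empty metric is a choice made for this sketch.
        return self.total / self.count if self.count else 0.0


# Two evaluation batches of scores over 4 classes, with integer labels.
batches = [
    ([[0.1, 0.5, 0.3, 0.1], [0.6, 0.3, 0.05, 0.05]], [2, 3]),  # hit, miss
    ([[0.4, 0.3, 0.3, 0.0]], [2]),  # hit: class 2 ties class 1 at the boundary
]

metric = StreamingMean()
for scores, labels in batches:
    metric.update(in_top_k_ref(scores, labels, k=2))

print(metric.result())  # → 0.6666666666666666
```

Averaging the two per-batch means instead would give (0.5 + 1.0) / 2 = 0.75 here. Keeping a running total and count weights every example equally regardless of batch size, which is the reason to use a streaming mean rather than averaging batch results.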