
Should a custom Keras true-positive metric always return an integer?


I'm working with a non-standard dataset where both y_true and y_pred have shape (batch x 5 x 1). A batch sample i is "true" if any value of y_true[i] is > 0., and it is predicted "true" if any value of y_pred[i] is >= b, where b is a threshold between 0 and 1.
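To make the labelling rule concrete, here is a plain-Python sketch (no TensorFlow) that applies it to one small batch and counts the four outcomes. It is an illustration under assumptions: each sample is written as a flat list of 5 values (the trailing size-1 dimension is dropped), the helper names are hypothetical, and b = 0.5 is an arbitrary example threshold.

```python
# Plain-Python sketch of the sample-level rule (no TensorFlow).
# Each sample is a flat list of 5 values standing in for y_true[i] or y_pred[i];
# the trailing size-1 dimension is dropped for readability.

def is_true(sample):
    """A sample is "true" if any of its values is > 0."""
    return any(v > 0.0 for v in sample)


def is_predicted_true(sample, b):
    """A sample is predicted "true" if any of its values is >= threshold b."""
    return any(v >= b for v in sample)


def confusion_counts(y_true, y_pred, b=0.5):
    """Return (TP, TN, FP, FN) for one batch under the sample-level rule."""
    tp = tn = fp = fn = 0
    for t, p in zip(y_true, y_pred):
        actual, predicted = is_true(t), is_predicted_true(p, b)
        if actual and predicted:
            tp += 1
        elif not actual and not predicted:
            tn += 1
        elif predicted:
            fp += 1  # predicted true, actually not true
        else:
            fn += 1  # actually true, predicted not true
    return tp, tn, fp, fn


y_true = [
    [0.0, 0.0, 0.3, 0.0, 0.0],  # true
    [0.0, 0.0, 0.0, 0.0, 0.0],  # not true
    [0.0, 1.0, 0.0, 0.0, 0.0],  # true
]
y_pred = [
    [0.1, 0.2, 0.7, 0.0, 0.0],  # predicted true (0.7 >= 0.5)
    [0.6, 0.0, 0.0, 0.0, 0.0],  # predicted true (0.6 >= 0.5)
    [0.1, 0.2, 0.3, 0.4, 0.0],  # predicted not true
]
print(confusion_counts(y_true, y_pred, b=0.5))  # (1, 0, 1, 1)
```

Because every sample lands in exactly one of the four categories, the per-batch counts always add up to the batch size.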

I've defined this custom Keras metric to calculate the number of true positives in a batch:

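import tensorflow as tf
from tensorflow.keras import backend as K

def TP(threshold=0.0):

    def TP_(Y_true, Y_pred):
        # Element-wise: a label is positive if > 0, a prediction if >= threshold
        Y_true = tf.where(Y_true > 0., tf.ones(tf.shape(Y_true)), tf.zeros(tf.shape(Y_true)))
        Y_pred_true = tf.where(Y_pred >= threshold, tf.ones(tf.shape(Y_pred)), tf.zeros(tf.shape(Y_pred)))

        # Per sample: count the positive entries along the length-5 axis
        Y_true = K.sum(Y_true, axis=1)
        Y_pred_true = K.sum(Y_pred_true, axis=1)

        # A sample is positive if at least one of its entries is positive
        Y_true = tf.where(Y_true > 0., tf.ones(tf.shape(Y_true)), tf.zeros(tf.shape(Y_true)))
        Y_pred_true = tf.where(Y_pred_true > 0., tf.ones(tf.shape(Y_pred_true)), tf.zeros(tf.shape(Y_pred_true)))

        # True positive: label and prediction are both positive (sum == 2)
        Y = tf.math.add(Y_true, Y_pred_true)
        tp = tf.where(Y == 2, tf.ones(tf.shape(Y)), tf.zeros(tf.shape(Y)))
        tp = K.sum(tp)

        return tp

    return TP_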

During training, this metric sometimes shows non-integer values. Is this because Keras averages the values over all batches?

I have similar custom metrics for true negatives, false positives, and false negatives. Should the sum of all four of these values during training be an integer?
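A small arithmetic sketch in plain Python of what averaging per-batch counts does, assuming (as the answer below says) that Keras reports a function-style metric as the mean of its per-batch values over the epoch so far. The per-batch numbers are made up for illustration, with a smaller final batch.

```python
# Hypothetical per-batch results for one epoch: (batch_size, tp, tn, fp, fn).
# The last batch is smaller, as happens when the dataset size is not a
# multiple of the batch size.
batches = [
    (4, 2, 1, 1, 0),
    (4, 1, 2, 0, 1),
    (2, 1, 0, 0, 1),
]
num_batches = len(batches)

# Function-style metric: the reported value is the mean of the per-batch counts.
mean_tp = sum(b[1] for b in batches) / num_batches

# Stateful metric: the reported value is the running total.
total_tp = sum(b[1] for b in batches)

# Each batch's four counts add up to that batch's size, so the four averaged
# metrics add up to the mean batch size.
mean_of_sums = sum(sum(b[1:]) for b in batches) / num_batches
mean_batch_size = sum(b[0] for b in batches) / num_batches

print(mean_tp)          # 1.333... (4 / 3), not an integer
print(total_tp)         # 4
print(mean_of_sums)     # 3.333... (10 / 3), equal to mean_batch_size
```

With a partial last batch, even the sum of the four averaged metrics is fractional, because it equals the mean batch size. If every batch is full, that sum equals the batch size, while each individual averaged count can still be fractional.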


Solution

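  • A two-part answer. First, yes: Keras averages a function-style metric over the batches seen so far in the epoch, so a per-batch count can show up as a non-integer. The built-in tf.keras.metrics.TruePositives metric behaves differently because it is stateful: it accumulates its count across batches, so the value it reports is a whole number.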

    Second, your metric does not persist any state between batches, so Keras simply takes the mean of the value your function returns. To report a running count instead, subclass tf.keras.metrics.Metric like so:

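    import tensorflow as tf
    from tensorflow.keras import backend as K

    class TP(tf.keras.metrics.Metric):

        def __init__(self, threshold=0.5, **kwargs):
            super().__init__(**kwargs)

            self.threshold = threshold
            # Running count of true positives, accumulated across batches
            self.true_positives = self.add_weight(name='true_positives', initializer='zeros',
                                                  dtype=tf.int32)

        def update_state(self, y_true, y_pred, sample_weight=None):
            # sample_weight is accepted for API compatibility but ignored here
            y_true = tf.cast(y_true, y_pred.dtype)

            # Element-wise: a label is positive if > 0 (as in the question),
            # a prediction is positive if >= threshold
            y_true = tf.where(y_true > 0.,
                              tf.ones(tf.shape(y_true)),
                              tf.zeros(tf.shape(y_true)))
            y_pred_true = tf.where(y_pred >= self.threshold,
                                   tf.ones(tf.shape(y_pred)),
                                   tf.zeros(tf.shape(y_pred)))

            # Per sample: count the positive entries along the length-5 axis
            y_true = K.sum(y_true, axis=1)
            y_pred_true = K.sum(y_pred_true, axis=1)

            # A sample is positive if at least one of its entries is positive
            y_true = tf.where(y_true > 0., tf.ones(tf.shape(y_true)),
                              tf.zeros(tf.shape(y_true)))
            y_pred_true = tf.where(y_pred_true > 0.,
                                   tf.ones(tf.shape(y_pred_true)),
                                   tf.zeros(tf.shape(y_pred_true)))

            # True positive: label and prediction are both positive (sum == 2)
            Y = tf.math.add(y_true, y_pred_true)
            tp = tf.where(Y == 2, tf.ones(tf.shape(Y), dtype=tf.int32),
                          tf.zeros(tf.shape(Y), dtype=tf.int32))
            tp = K.sum(tp)
            self.true_positives.assign_add(tp)

        def result(self):
            return self.true_positives

        def get_config(self):
            # Keep the base config (name, dtype) so the metric can be re-created
            config = super().get_config()
            config.update({'threshold': self.threshold})
            return config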
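For intuition about why the stateful version reports whole numbers, here is a plain-Python analogue of its lifecycle. CountingMetric and run_epoch are hypothetical stand-ins whose method names mirror tf.keras.metrics.Metric; the sketch assumes, as in Model.fit, that metric state is reset at the start of each epoch and result() is read after every batch.

```python
class CountingMetric:
    """Hypothetical stand-in mirroring the Metric method names (no TensorFlow)."""

    def __init__(self):
        self.total = 0

    def reset_state(self):
        # Clear the running count, as happens at the start of each epoch
        self.total = 0

    def update_state(self, batch_count):
        # Add this batch's count to the running total
        self.total += batch_count

    def result(self):
        # The reported value is the accumulated count, always an integer
        return self.total


def run_epoch(metric, per_batch_counts):
    """Return the value reported after each batch of one epoch."""
    metric.reset_state()
    reported = []
    for count in per_batch_counts:
        metric.update_state(count)
        reported.append(metric.result())
    return reported


metric = CountingMetric()
print(run_epoch(metric, [2, 1, 1]))  # [2, 3, 4]
print(run_epoch(metric, [0, 3]))     # [0, 3]
```

Each reported value is a cumulative count rather than a mean, which is why it never becomes fractional within an epoch.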