Tags: python, machine-learning, keras, neural-network, multiclass-classification

Multi-class classification Confusion Matrix as Metric in Keras Neural Network


I have a Keras neural network with 4 output neurons (0, 1, 2, 3).

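from keras import metrics, models
from keras.layers import Dense

model = models.Sequential()
model.add(Dense(6000, activation='relu', input_shape=(4547,)))
model.add(Dense(3000, activation='relu'))
model.add(Dense(1000, activation='relu'))
model.add(Dense(4, activation='softmax'))  # one output per class 0-3

# Integer labels (sparse loss) need the sparse accuracy variant;
# CategoricalAccuracy expects one-hot encoded targets.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=[metrics.SparseCategoricalAccuracy()])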

The output is a combination of two binary labels with the same input features, so I just wanted to see how the multi-class classification performs compared to the two binary classifications.

When using Keras for binary classification you can use the metrics TP, FP, TN and FN and then have all the values of the confusion matrix epoch-wise, as well as for the final evaluation.

I was wondering if you can do something similar with multi-class classification. I tried to use a custom-built metric, but the values that I'm getting are not plausible.

def true_pos(y_true, y_pred):
    y_pred = K.argmax(y_pred, axis = 1)
    return tf.math.reduce_sum(tf.cast(tf.math.logical_and(tf.math.equal(y_pred, 0),
                              tf.math.equal(tf.cast(y_pred, tf.float32), y_true)), tf.int32))

I also tried using tf.math.confusion_matrix and then selecting certain values in the array to then return the values of the confusion matrix, but those results were also not plausible.

def conf_matrix(y_true, y_pred):
    y_pred = K.argmax(y_pred, axis = 1)
    cm = tf.math.confusion_matrix(y_true, y_pred)
    return cm[0, 0]

In both cases when I test the function in isolation (using my testing data: y_test and model.predict(X_test)) they do what they are supposed to, but once I train the model the values coming out of these functions are not plausible.

When I use sklearn confusion matrix and model.predict() I can return a confusion matrix at the end of the training process, but I would like to track the values for each epoch.

I was also wondering more generally why the confusion matrix (or at least the values it is composed of) is not a basic feature in Keras for multi-class classification, since you can easily calculate a range of other metrics from it.


Solution

  • I have a solution now, for which I adapted the custom metric from a Keras documentation example (https://keras.io/guides/training_with_built_in_methods/ - Custom metrics)

    import keras
    import tensorflow as tf

    @keras.saving.register_keras_serializable()
    class MultiClassConfusionMatrix(keras.metrics.Metric):
        """Counts the samples of true class t that were predicted as class p."""

        def __init__(self, p, t, **kwargs):
            super().__init__(name='P'+str(p)+'T'+str(t), **kwargs)
            # Running count for one cell of the confusion matrix.
            self.cm_value = self.add_weight(name="cm", initializer="zeros")
            self.pred = p
            self.true = t

        def update_state(self, y_true, y_pred, sample_weight=None):
            # Flatten labels and predictions to shape (batch,) so the comparison
            # is element-wise whether y_true arrives as (batch,) or (batch, 1).
            y_true = tf.reshape(tf.cast(y_true, "int32"), [-1])
            y_pred = tf.cast(tf.argmax(y_pred, axis=1), "int32")
            values = (y_true == self.true) & (y_pred == self.pred)
            values = tf.cast(values, "float32")
            if sample_weight is not None:
                sample_weight = tf.reshape(tf.cast(sample_weight, "float32"), [-1])
                values = tf.multiply(values, sample_weight)
            self.cm_value.assign_add(tf.reduce_sum(values))

        def result(self):
            return self.cm_value

        def reset_state(self):
            self.cm_value.assign(0.0)
    

    ...

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=[MultiClassConfusionMatrix(p=0,t=0)])
    

    I compared the results of this metric with the confusion_matrix function from sklearn and the values are the same, so it works as intended. With this metric, each cell of the multi-class confusion matrix (a combination of predicted and true value) can be logged per epoch. The corresponding metrics can then be calculated and compared later on (e.g. accuracy for all true positive cases vs. accuracy for each individual class prediction).
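As a minimal sketch of that later calculation, the snippet below rebuilds a confusion matrix from logged cell values and derives per-class precision and recall plus overall accuracy. It assumes one metric instance per (predicted, true) pair was passed to compile, so that the logs contain keys of the form P{p}T{t}. The `logs` dictionary and its numbers are made up for illustration (a 2-class case to keep it short), and `cells_to_matrix` and `per_class_scores` are hypothetical helper names, not part of Keras.

```python
def cells_to_matrix(logs, n_classes):
    """Rebuild the matrix with rows = true class, columns = predicted class."""
    return [[logs[f"P{p}T{t}"] for p in range(n_classes)]
            for t in range(n_classes)]


def per_class_scores(cm):
    """Return {class: (precision, recall)} from a square confusion matrix."""
    n = len(cm)
    scores = {}
    for c in range(n):
        tp = cm[c][c]
        predicted_c = sum(cm[t][c] for t in range(n))  # column sum
        actual_c = sum(cm[c])                          # row sum
        precision = tp / predicted_c if predicted_c else 0.0
        recall = tp / actual_c if actual_c else 0.0
        scores[c] = (precision, recall)
    return scores


# Made-up values standing in for the logged metrics of one epoch.
logs = {"P0T0": 8.0, "P1T0": 2.0, "P0T1": 1.0, "P1T1": 9.0}

cm = cells_to_matrix(logs, n_classes=2)
scores = per_class_scores(cm)
accuracy = sum(cm[c][c] for c in range(len(cm))) / sum(map(sum, cm))
print(cm, scores, accuracy)
```

For the 4-class model in the question, the same helpers would apply with n_classes=4, provided all 16 cells were logged.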

    I assume my first function returned the values per batch and these returns were then averaged over the span of the epoch. Instead, as I understand it, in this case the MultiClassConfusionMatrix object is instantiated at the beginning of training, a variable for cm_value is declared and subsequent classifications are added to it via assign_add(). I would be happy if someone could tell me if this assumption is correct.
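The difference between the two behaviours can be illustrated without Keras. This is a plain-Python sketch under the assumption described above: a metric passed as a plain function is evaluated per batch and the per-batch results are averaged over the epoch, whereas a stateful metric keeps one running total. The batch contents and the helper `count_cell` are made up for illustration.

```python
def count_cell(y_true, y_pred, t, p):
    """Count samples with true class t that were predicted as class p."""
    return sum(1 for yt, yp in zip(y_true, y_pred) if yt == t and yp == p)


# Three made-up batches of (true labels, predicted labels).
batches = [
    ([0, 0, 1, 2], [0, 0, 1, 2]),
    ([0, 1, 1, 3], [0, 1, 0, 3]),
    ([0, 0, 0, 2], [0, 0, 3, 2]),
]

# Value of the (true=0, predicted=0) cell in each batch.
per_batch = [count_cell(yt, yp, t=0, p=0) for yt, yp in batches]

# Function-style metric (assumed behaviour): mean of the per-batch values.
averaged = sum(per_batch) / len(per_batch)

# Stateful metric: running total, as with assign_add() in update_state().
accumulated = 0
for value in per_batch:
    accumulated += value

print(per_batch, averaged, accumulated)
```

Only the accumulated value matches what sklearn's confusion_matrix would report for the same cell over the whole epoch. The averaged value depends on the batch size, which would explain counts that look implausible.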