
How can I use keras categorical_accuracy with a multi-dimensional output?


I'm trying to evaluate my Keras neural network using the categorical_accuracy metric. However, it only compares along the last axis, treating each innermost list as a separate prediction, so if I have this as my target:

[[[0, 1, 0], [0, 0, 0]], ...]

and this as my output:

[[[0.1, 0.8, 0.2], [0.3, 0.4, 0.2]], ...]

It returns an accuracy of 0.5 for that item, because it compares each of the two inner lists separately. I would instead like to take the argmax over the entire output of each sample (the output is 3D), so the example above should score an accuracy of 1.
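To see where the 0.5 comes from, here is a minimal sketch (assuming TensorFlow 2.x with tf.keras) that runs the built-in function on the example above, and then on the same sample flattened into a single row. `categorical_accuracy` takes the argmax along the last axis only, so the sample is scored row by row. Note that the all-zero target row `[0, 0, 0]` has no real class; its argmax simply lands on index 0 in practice.

```python
import tensorflow as tf

# One sample of shape (2, 3), batched to shape (1, 2, 3).
y_true = tf.constant([[[0, 1, 0], [0, 0, 0]]], dtype=tf.float32)
y_pred = tf.constant([[[0.1, 0.8, 0.2], [0.3, 0.4, 0.2]]], dtype=tf.float32)

# Built-in behaviour: argmax along the last axis, one 0/1 value per inner row.
per_row = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
print(per_row.numpy())                  # [[1. 0.]]
print(float(tf.reduce_mean(per_row)))   # 0.5

# Desired behaviour: flatten each sample so argmax runs over all 6 values.
flat_true = tf.reshape(y_true, (1, -1))
flat_pred = tf.reshape(y_pred, (1, -1))
per_sample = tf.keras.metrics.categorical_accuracy(flat_true, flat_pred)
print(per_sample.numpy())               # [1.]
```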

I've tried this:

class Accuracy(keras.metrics.CategoricalAccuracy):
    def __init__(self):
        super().__init__()

    def call(self, inputs, **kwargs):
        super().call((tf.reshape(inputs[0], (output_size,)), tf.reshape(inputs[1], (output_size,))), **kwargs)

But it still seems to return the same values. Is there any other way to adapt a built-in Keras metric without resorting to writing my own from scratch, which is likely to run much more slowly?
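As far as I can tell, the override above never runs: Keras `Metric` objects receive their data through `update_state(y_true, y_pred, sample_weight)`, not `call`, and the overridden method also returns nothing. Below is a minimal sketch of the same idea hooked into `update_state` instead, assuming TensorFlow 2.x; the class name `FlatCategoricalAccuracy` is hypothetical. It still delegates the actual comparison to the built-in `CategoricalAccuracy` and only adds two reshapes per batch, so it should not be noticeably slower.

```python
import tensorflow as tf


class FlatCategoricalAccuracy(tf.keras.metrics.CategoricalAccuracy):
    """Hypothetical subclass: categorical accuracy over all non-batch axes."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Keep the batch axis and collapse everything else,
        # e.g. (batch, 2, 3) -> (batch, 6), so argmax spans the whole sample.
        batch = tf.shape(y_true)[0]
        y_true = tf.reshape(y_true, (batch, -1))
        y_pred = tf.reshape(y_pred, (batch, -1))
        return super().update_state(y_true, y_pred, sample_weight=sample_weight)


m = FlatCategoricalAccuracy()
m.update_state(
    tf.constant([[[0, 1, 0], [0, 0, 0]]], dtype=tf.float32),
    tf.constant([[[0.1, 0.8, 0.2], [0.3, 0.4, 0.2]]], dtype=tf.float32),
)
print(float(m.result()))  # 1.0
```

It can then be passed to `model.compile(metrics=[FlatCategoricalAccuracy()])` like any built-in metric object.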

UPDATE: I have written a function to work around this:

def cat_acc(y_true, y_pred):
    return tf.keras.metrics.categorical_accuracy(tf.keras.backend.flatten(y_true),
                                                 tf.keras.backend.flatten(y_pred))

However, this flattens the entire batch in one go, so it only ever returns 1 or 0: 1 when the maximum value across the whole batch is in the right place, and 0 when it is not. I cannot find a sensible way to loop over the batch without Keras throwing a ValueError when it compiles the metric as my model is created.
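The difference is easy to reproduce with a batch of two samples, one predicted correctly and one not. In this sketch (assuming TensorFlow 2.x; the data are made up for illustration), flattening with `tf.reshape(..., [-1])`, which is what `tf.keras.backend.flatten` does, merges both samples into one vector with a single argmax, whereas reshaping to `(batch, -1)` keeps one row, and therefore one argmax, per sample.

```python
import tensorflow as tf

# Two made-up samples of shape (2, 3): the first is predicted correctly,
# the second is not (true class at flat index 3, prediction at flat index 0).
y_true = tf.constant([[[0, 1, 0], [0, 0, 0]],
                      [[0, 0, 0], [1, 0, 0]]], dtype=tf.float32)
y_pred = tf.constant([[[0.1, 0.8, 0.2], [0.3, 0.4, 0.2]],
                      [[0.9, 0.1, 0.0], [0.2, 0.3, 0.1]]], dtype=tf.float32)

# Whole-batch flatten: one argmax over all 12 values, so a single 0 or 1.
whole = tf.keras.metrics.categorical_accuracy(tf.reshape(y_true, [-1]),
                                              tf.reshape(y_pred, [-1]))
print(float(whole))  # 0.0

# Per-sample flatten: keep the batch axis, one argmax per sample.
batch = tf.shape(y_true)[0]
per_sample = tf.keras.metrics.categorical_accuracy(tf.reshape(y_true, (batch, -1)),
                                                   tf.reshape(y_pred, (batch, -1)))
print(per_sample.numpy())                  # [1. 0.]
print(float(tf.reduce_mean(per_sample)))   # 0.5
```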


Solution

I eventually came up with the following solution:

import tensorflow as tf

def cat_acc(y_true, y_pred):
    # Reshape to (batch, output_size) so argmax runs over each whole sample, then average.
    return tf.reduce_mean(tf.keras.metrics.categorical_accuracy(tf.reshape(y_true, (-1, output_size)),
                                                                tf.reshape(y_pred, (-1, output_size))))

Here, output_size is the total number of elements in one sample of the model's output, e.g. for a model with an output of shape (3, 2, 2), the output_size variable would be 12 (3 × 2 × 2).
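A variant of the same idea that avoids hard-coding `output_size`, sketched under the assumption of TensorFlow 2.x (with the TensorFlow backend if you are on Keras 3); the name `flat_cat_acc` is hypothetical. Reading the batch size with `tf.shape` at run time lets the reshape adapt to any per-sample output shape.

```python
import tensorflow as tf


def flat_cat_acc(y_true, y_pred):
    """Hypothetical metric: per-sample categorical accuracy over all non-batch axes."""
    # tf.shape gives the dynamic batch size, so this also works when the
    # batch dimension is None at compile time.
    batch = tf.shape(y_true)[0]
    # e.g. (batch, 3, 2, 2) -> (batch, 12), without needing output_size.
    y_true = tf.reshape(y_true, (batch, -1))
    y_pred = tf.reshape(y_pred, (batch, -1))
    # One 0/1 value per sample; Keras averages these itself when the
    # function is passed to model.compile(metrics=[...]).
    return tf.keras.metrics.categorical_accuracy(y_true, y_pred)


# Batch of two (3, 2, 2) outputs: only the first sample's flat argmax matches.
y_true = tf.reshape(tf.one_hot([5, 7], depth=12), (2, 3, 2, 2))
y_pred = tf.reshape(tf.one_hot([5, 2], depth=12), (2, 3, 2, 2))
print(flat_cat_acc(y_true, y_pred).numpy())  # [1. 0.]
```

Dropping the `tf.reduce_mean` and returning the per-sample vector leaves the averaging to Keras. As I understand it, Keras then weights every sample equally, whereas a pre-reduced scalar counts each batch once, so a smaller final batch would carry slightly more weight.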