
calling Keras Model.evaluate() on every batch element separately


I would like to call tf.keras.Model.evaluate() (or a similar method) on a batch of my test_data, and I would like to get back the losses/metrics separately for every batch element. So if a batch is 64 elements long, I would like to get back a list of 64 losses/metrics.

I need this in order to find outliers in the test dataset.

I tried calling test_on_batch() or evaluate() on single batches, but these methods aggregate the batch results (presumably via the mean), and evaluating every element as its own batch, although possible, takes 10-20x longer on my GPU.

I also tried calling predict() and computing the losses/metrics manually, but this approach also suffers a steep performance drop, caused by the extra manual step of computing every loss/metric from the test labels and the predictions.

Is there a way to do this without compromising performance?


Solution

  • Using the TensorFlow metric/loss functions together with model.predict() is fast and doesn't involve loops.

    Consider this dummy classification task:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import Sequential
    from tensorflow.keras.layers import Flatten, Dense
    
    X = np.random.uniform(0, 1, (64, 28, 28, 1))
    y = np.random.randint(0, 2, 64)
    
    model = Sequential([Flatten(), Dense(2, activation='softmax')])
    model.compile('adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    model.fit(X, y, epochs=3)
    

    You can evaluate the score for every batch element this way:

    preds = model.predict(X)  # compute the predictions once and reuse them
    
    scce = tf.keras.losses.sparse_categorical_crossentropy(y, preds)
    # scce.shape ==> (64,)
    
    scca = tf.keras.metrics.sparse_categorical_accuracy(y, preds)
    # scca.shape ==> (64,)
    

    These per-element scores are exactly the values that model.evaluate() aggregates:

    scce_eval, scca_eval = model.evaluate(X,y, verbose=0)
    

    scce_eval is equal to tf.reduce_mean(scce)

    scca_eval is equal to tf.reduce_mean(scca)
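
    Since the original goal was to find outliers in the test dataset, here is a minimal sketch of how the per-element losses could be used for that. The hard-coded loss values and the cutoff k = 2 are illustrative assumptions; in practice, losses would be scce.numpy() from the snippet above.

    ```python
    import numpy as np

    # Per-element losses, e.g. scce.numpy() from the snippet above
    # (hard-coded here so the example is self-contained)
    losses = np.array([0.1, 2.5, 0.3, 4.0, 0.2])

    # Indices of the k highest-loss elements -- candidate outliers.
    # k = 2 is an arbitrary cutoff for this sketch.
    k = 2
    outlier_idx = np.argsort(losses)[::-1][:k]
    print(outlier_idx)  # -> [3 1]
    ```

    If a fixed k is not appropriate, a percentile-based cutoff (e.g. keeping elements whose loss exceeds np.percentile(losses, 95)) works just as well.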