python, tensorflow, machine-learning, keras

TensorFlow model predictions with a dataset are just guessing, while evaluations are good


I am training an image classifier to detect rodents. When I fit the model and evaluate it, I get strong metrics that indicate it is performing well. However, when I run predictions with tf.data datasets, the model seems to be just guessing. When I predict on individual images, it is accurate again.

Here is the function that does all the evaluating:

import os
import numpy as np
import matplotlib.pyplot as plt

def evaluate_model(hist, model_fp, model, train, test):
    # derive a base path for the metric plots from the model file path
    metrics_path, _ = os.path.splitext(model_fp)
    # collect the ground-truth labels by iterating over the datasets
    train_labels = np.concatenate([y for x, y in train], axis=0)
    test_labels = np.concatenate([y for x, y in test], axis=0)

    plt.figure()
    ... # plot accuracy metrics from history

    metrics = ['loss', 'prc', 'precision', 'recall']
    ... # plot more metrics from history

    # evaluate on the test and train datasets
    results = model.evaluate(test, verbose=0)
    for name, value in zip(model.metrics_names, results):
        print(name, ': ', value)
    print('\n')

    results = model.evaluate(train, verbose=0)
    for name, value in zip(model.metrics_names, results):
        print(name, ': ', value)
    print('\n')

    train_predictions = model.predict(train)
    test_predictions = model.predict(test)

    # PRC, ROC, CM GENERATED FROM THESE PREDICTIONS
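For reference, the elided plotting steps might look roughly like this (a minimal sketch, assuming hist is the History object returned by model.fit and that the model was compiled with metrics named accuracy, prc, precision, and recall):

def plot_history_metric(hist, metric, metrics_path):
    # plot training vs. validation curves for one metric recorded in History
    plt.figure()
    plt.plot(hist.history[metric], label='train ' + metric)
    plt.plot(hist.history['val_' + metric], label='val ' + metric)
    plt.xlabel('epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.savefig(metrics_path + '_' + metric + '.png')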

Here is the function that grabs and batches the images:

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

def prepare_images(data_dir):
    training_data = image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        image_size=IMG_SHAPE,
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=123
    )

    validation_data = image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        image_size=IMG_SHAPE,
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=123  # must match the training seed, or the two subsets can overlap
    )

    # create test data set using 20% of the validation batches
    val_batches = tf.data.experimental.cardinality(validation_data)
    test_data = validation_data.take(val_batches // 5)
    validation_data = validation_data.skip(val_batches // 5)
    return training_data, validation_data, test_data
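For context, a typical invocation, with hypothetical values for the IMG_SHAPE and BATCH_SIZE constants the snippet relies on:

IMG_SHAPE = (160, 160)  # hypothetical input size; use whatever the model expects
BATCH_SIZE = 32         # hypothetical batch size

train_ds, val_ds, test_ds = prepare_images('data/rodents')  # path is illustrative
print('train batches:', tf.data.experimental.cardinality(train_ds).numpy())
print('val batches:', tf.data.experimental.cardinality(val_ds).numpy())
print('test batches:', tf.data.experimental.cardinality(test_ds).numpy())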

Fitting the model:

def train_model(model, train, val, test, epochs, cb, class_weights, model_fp):
    history = model.fit(
        train,
        validation_data=val,
        epochs=epochs,
        # batch_size is omitted: the tf.data datasets are already batched
        callbacks=cb,
        class_weight=class_weights,
        verbose=1
    )
    return history
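And a sketch of how this might be wired up; the callback and class weights below are illustrative, not from the original post:

from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_prc', mode='max', patience=10,
                           restore_best_weights=True)
class_weights = {0: 1.0, 1: 2.0}  # illustrative weights for an imbalanced binary problem
history = train_model(model, train_ds, val_ds, test_ds, epochs=50,
                      cb=[early_stop], class_weights=class_weights,
                      model_fp='models/rodent_classifier.keras')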

So when I evaluate the model on the test dataset, the output is as follows:

loss :  0.1926884800195694
tp :  197.0
fp :  20.0
tn :  159.0
fn :  8.0
accuracy :  0.9270833134651184
precision :  0.9078341126441956
recall :  0.9609755873680115
auc :  0.979574978351593
prc :  0.9817492961883545
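(These figures are internally consistent: accuracy = (197 + 159) / 384 ≈ 0.9271, precision = 197 / 217 ≈ 0.9078, recall = 197 / 205 ≈ 0.9610, so evaluate is clearly not guessing.)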

However, when predicting on that same test dataset...

58/58 [==============================] - 6s 91ms/step
3/3 [==============================] - 1s 142ms/step
Legitimate Transactions Detected (True Negatives):  98
Legitimate Transactions Incorrectly Detected (False Positives):  92
Fraudulent Transactions Missed (False Negatives):  83
Fraudulent Transactions Detected (True Positives):  111
Total Fraudulent Transactions:  194
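(That works out to (98 + 111) / 384 ≈ 54% accuracy over the same 384 test images, barely better than a coin flip.)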

And when I predict 10 images individually, 9 of them come out correct.
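The single-image checks were along these lines (a hypothetical sketch; the exact loading code isn't shown above, and it assumes any rescaling happens inside the model):

from tensorflow.keras.utils import load_img, img_to_array

img = load_img('rodent_01.jpg', target_size=IMG_SHAPE)  # illustrative path
batch = np.expand_dims(img_to_array(img), axis=0)       # add a batch dimension
print(model.predict(batch))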

So why does model.predict seem to be guessing randomly when given a batched dataset, but not with individual images? Is there a workaround? Thanks.


Solution

  • OK, so I figured it out. I'm not sure whether this is a documentation error or a TensorFlow bug, because the docs clearly state that you can pass a tf.data dataset as the x input to model.predict, yet when I do, the predictions look like random guessing, even though I made sure not to specify a batch size, as the docs require. The likely cause is that a dataset built with shuffle=True reshuffles on every iteration, so labels collected in one pass over the dataset do not line up with predictions computed in a second pass. The workaround is to pull the images and labels out of the dataset in a single pass:

    x_batches, y_batches = [], []
    for image_batch, labels_batch in test:
        # gather every batch in a single pass so images and labels stay paired
        x_batches.append(image_batch.numpy())
        y_batches.append(labels_batch.numpy())
    x_test = np.concatenate(x_batches, axis=0)
    y_test = np.concatenate(y_batches, axis=0)


    Then call predict with just x_test, and plot the metrics from those predictions together with y_test:

    test_predictions = model.predict(x_test)
    plot_cm(test_predictions, y_test, p, metrics_path) # function to plot confusion matrix
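
    An alternative sketch that avoids holding the whole test set in memory: predict batch by batch in a single pass over the dataset, so images and labels always stay paired (same model and test objects as above):

    y_true, y_pred = [], []
    for image_batch, labels_batch in test:
        # predicting while still holding this batch's labels keeps the
        # pairing intact even if the dataset reshuffles between iterations
        y_pred.append(model.predict_on_batch(image_batch))
        y_true.append(labels_batch.numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    y_true = np.concatenate(y_true, axis=0)
    plot_cm(y_pred, y_true, p, metrics_path)

    Building the test split with shuffle=False should also keep the iteration order stable, which would make the original dataset-based predict usable again.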