Tags: python, tensorflow, deep-learning, computer-vision, artificial-intelligence

Inconsistent Model Predictions When Using Entire Validation Dataset vs. Batch Sampling in TensorFlow


I am training a deep learning model using TensorFlow and Keras on an image classification task. My model achieves high validation accuracy when evaluated using the validation_ds dataset. However, when I manually sample a batch from the validation dataset and make predictions, the results are significantly different and much worse.

Here is how I am creating and using the validation dataset:

import tensorflow as tf

validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_directory1,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

When I use the entire validation dataset for predictions, I do the following:

import numpy as np
from sklearn.metrics import classification_report

# predict labels
y_pred = functional_model.predict(validation_ds)
y_pred_classes = np.argmax(y_pred, axis=1)

# true labels
y_true = np.concatenate([y.numpy() for x, y in validation_ds], axis=0)
y_true_classes = np.argmax(y_true, axis=1)

# Generate classification report
report = classification_report(y_true_classes, y_pred_classes, target_names=class_names)
print("Classification Report:")
print(report)

This approach shows much worse performance. However, when I iterate over the validation dataset and make predictions batch by batch like this:

y_true = []
y_pred = []

for images, labels in validation_ds:
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    predictions = functional_model.predict(images)
    y_pred.extend(np.argmax(predictions, axis=1))

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Generate classification report
report = classification_report(y_true, y_pred, target_names=class_names)
print("Classification Report:")
print(report)

The classification report shows reasonable results.

Why are the predictions inconsistent between using the entire validation dataset and a sampled batch?

How can I ensure consistent and accurate predictions for the validation dataset?


Solution

  • tf.keras.preprocessing.image_dataset_from_directory() defaults to shuffle=True, and your call does not override it, so your validation dataset is reshuffled every time you iterate over it.

    In your first example, you run inference and build y_true in two separate passes over the dataset. Each pass sees the samples in a different order, so the predictions and the labels no longer line up. The predictions themselves are fine; only their pairing with the labels is wrong.

    In the second example, you take the predictions and the ground truth from the same batch in a single pass, so they are in the same order, and all is well.
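The single-pass idea can be sketched without TensorFlow, to show the effect in isolation. This is a minimal illustration, not the asker's actual pipeline: `collect_true_and_pred` and `ReshufflingBatches` are hypothetical names introduced here, the toy data stands in for images, and the labels are assumed to be one-hot encoded (as the `np.argmax(..., axis=1)` calls in the question suggest).

```python
import numpy as np


def collect_true_and_pred(predict_fn, dataset):
    """Gather class indices for labels and predictions in ONE pass.

    predict_fn: callable mapping a batch of images to per-class scores.
    dataset: any iterable yielding (images, one_hot_labels) batches.
    """
    y_true, y_pred = [], []
    for images, labels in dataset:
        # Labels and scores come from the same batch, so they stay paired
        # regardless of the order in which batches are produced.
        scores = np.asarray(predict_fn(images))
        y_true.append(np.argmax(np.asarray(labels), axis=1))
        y_pred.append(np.argmax(scores, axis=1))
    return np.concatenate(y_true), np.concatenate(y_pred)


class ReshufflingBatches:
    """Hypothetical stand-in for a dataset that reshuffles on every iteration."""

    def __init__(self, images, labels, batch_size, seed=0):
        self.images, self.labels, self.batch_size = images, labels, batch_size
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        order = self.rng.permutation(len(self.images))  # fresh order per pass
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            yield self.images[idx], self.labels[idx]


# Toy data: each "image" is its own one-hot label, so the identity function
# acts as a perfect classifier.
labels = np.eye(4)[np.arange(40) % 4]
toy_ds = ReshufflingBatches(images=labels, labels=labels, batch_size=8)

# One pass: predictions and labels are taken from the same batches.
y_true, y_pred = collect_true_and_pred(lambda x: x, toy_ds)
single_pass_accuracy = np.mean(y_true == y_pred)

# Two separate passes, as in the first example: the two orders differ.
pred_pass = np.concatenate([np.argmax(x, axis=1) for x, _ in toy_ds])
true_pass = np.concatenate([np.argmax(y, axis=1) for _, y in toy_ds])
two_pass_accuracy = np.mean(true_pass == pred_pass)

print(f"single pass: {single_pass_accuracy:.2f}, two passes: {two_pass_accuracy:.2f}")
```

With the objects from the question, the same helper would presumably be called as collect_true_and_pred(lambda x: functional_model.predict(x, verbose=0), validation_ds), and the two returned arrays passed to classification_report. This leaves the dataset and its train/validation split exactly as they were created; only the way labels and predictions are collected changes.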