Search code examples
Tags: python, tensorflow, keras, tensorflow-datasets

How can I retrieve the first N items from a TensorFlow batch dataset, and not an iterator that reevaluates to different items?


I would like to retrieve the first N items from a BatchDataset. I have tried a number of different ways to do this, and they all retrieve different items when reevaluated. However, I would like to retrieve N actual items, not an iterator that will continue to retrieve new items.

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
# Load the "training" subset of the images found under the "Images" directory,
# holding out validation_split=0.2 of the data for validation.
# NOTE(review): `shuffle` is not passed here, so the Keras default applies
# (True per the Keras docs — confirm); the answer below attributes the
# "different items on every pass" behaviour to that shuffling.
ds = tf.keras.utils.image_dataset_from_directory(
    "Images", 
    validation_split=0.2,
    seed=123,
    subset="training")

# Attempt to retrieve 9 items
# NOTE: take(9) selects the first 9 *elements* of ds. The loops below unpack
# each element as an (images, labels) pair and index into it, i.e. each
# element is a batch, so this is 9 batches rather than 9 single images.
test_ds = ds.take(9)

# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
# Outer loop: one (images, labels) batch per iteration of test_ds.
for images, labels in test_ds:
  # Inner loop: draw the first 9 samples of the current batch into the 3x3
  # grid. Every batch reuses the same 9 subplot slots.
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    # NOTE(review): class_names is not defined anywhere in this snippet —
    # presumably it is meant to be ds.class_names; confirm.
    plt.title(class_names[labels[i]])
    plt.axis("off")

#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is 
# that it should show the same images as above.
# 
plt.figure(figsize=(4, 4))
# Iterating test_ds a second time re-evaluates the dataset pipeline; it does
# not replay the tensors produced by the first loop.
for images, labels in test_ds:
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")


Solution

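  • By default, tf.keras.utils.image_dataset_from_directory creates the dataset with shuffle=True, so the data is reshuffled every time you iterate over it, and each pass over test_ds yields different items. You could set shuffle to False to get deterministic results: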

    import pathlib

    import matplotlib.pyplot as plt
    import tensorflow as tf


    def plot_items(dataset, class_names, grid=(3, 3)):
        """Plot one image per dataset element in a grid, titled with its class.

        Args:
            dataset: A dataset yielding (images, labels) batches; only the first
                sample of each batch is drawn, so it is intended for
                batch_size=1 datasets.
            class_names: Sequence mapping an integer label to its class name.
            grid: (rows, cols) of the subplot grid; the dataset must not yield
                more than rows * cols elements.
        """
        rows, cols = grid
        plt.figure(figsize=(4, 4))
        for i, (images, labels) in enumerate(dataset):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(images[0, ...].numpy().astype("uint8"))
            plt.title(class_names[labels.numpy()[0]])
            plt.axis("off")


    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)

    # shuffle=False keeps the element order fixed, so every pass over the
    # dataset yields the same items; batch_size=1 makes each element one image.
    ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(64, 64),
      batch_size=1,
      shuffle=False)

    # Use the real class names inferred from the directory structure instead of
    # placeholder labels. Read them from `ds`: the dataset returned by take()
    # does not carry the class_names attribute.
    class_names = ds.class_names

    # Retrieve the first 9 items
    test_ds = ds.take(9)

    # Plot the 9 items and their labels twice; both figures show the same items.
    plot_items(test_ds, class_names)
    plot_items(test_ds, class_names)

    # Needed when running as a script; notebooks render the figures inline.
    plt.show()
    

    enter image description here enter image description here

    If you are interested in other data samples, you can use the methods tf.data.Dataset.skip and tf.data.Dataset.take; for example, ds.skip(9).take(9) returns the nine items that follow the first nine.