Tags: python, tensorflow, keras, tensorflow-datasets

How to apply pre-processing to images of a tf.data.Dataset?


If I understand correctly, instead of loading a full dataset into memory like this:

import glob

import numpy as np
from tensorflow.keras.applications import imagenet_utils
from tensorflow.keras.utils import img_to_array, load_img

images = []
file_list = glob.glob('path/to/images/*.jpg')
for file in file_list:
    images.append(img_to_array(load_img(file, target_size=input_shape)))

images = np.stack(images, axis=0)
images = preprocess(images)

# classify the images
print("[INFO] classifying image with '{}'...".format(used_model))
predictions = model.predict(images)
decoded_predictions = imagenet_utils.decode_predictions(predictions)

one should use TensorFlow's tf.data utilities for better memory management and performance:

images = tf.keras.utils.image_dataset_from_directory(file_path, image_size=input_shape, labels=None)
AUTOTUNE = tf.data.AUTOTUNE
images = images.prefetch(buffer_size=AUTOTUNE)

# this line will now crash, because images is a tf.data.Dataset, not a NumPy array
images = preprocess(images)


# classify the image
print("[INFO] classifying image with '{}'...".format(used_model))
predictions = model.predict(images)
decoded_predictions = imagenet_utils.decode_predictions(predictions)

As the code above shows, I now have a different data structure, which will not work with the same code. My question is: how can I apply pre-processing to my data? All the corresponding tutorials seem to deal with training, while I want to do simple inference.

Additional question: how would this be done if the data comes from an S3 bucket (with the script running in an Airflow DAG)?


Solution

  • You can use tf.data.Dataset.map to apply preprocessing to your images or batches of images. Here is an example:

    import tensorflow as tf
    import pathlib
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 32
    
    train_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      seed=123,
      image_size=(180, 180),
      batch_size=batch_size)
    
    scale_layer = tf.keras.layers.Rescaling(1./255)
    
    def preprocess(images, labels):
      # Scale pixel values to [0, 1], then resize to 120x120
      images = tf.image.resize(scale_layer(images), [120, 120],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
      return images, labels
    
    train_ds = train_ds.map(preprocess)
    

    In your case, you just have images, so your mapped function can take and return only the image batch, without a labels argument.
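
    For the unlabeled case, a minimal sketch might look like this (using a small random tensor in place of `image_dataset_from_directory(..., labels=None)`, which yields batches of images with no labels; the dataset shape and sizes are placeholders):

    ```python
    import tensorflow as tf

    # Stand-in for an unlabeled dataset of 180x180 RGB images, as produced by
    # tf.keras.utils.image_dataset_from_directory(path, labels=None)
    images = tf.data.Dataset.from_tensor_slices(
        tf.random.uniform((8, 180, 180, 3), maxval=255.0)
    ).batch(4)

    scale_layer = tf.keras.layers.Rescaling(1. / 255)

    def preprocess(images):
        # No labels: take and return only the image batch
        return tf.image.resize(scale_layer(images), [120, 120],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    images = images.map(preprocess)

    for batch in images.take(1):
        print(batch.shape)  # (4, 120, 120, 3)
    ```

    The resulting dataset can be passed directly to `model.predict(images)`, which iterates over the batches for you.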