Tags: python, tensorflow, keras, tensorflow-datasets

How to use mobilenet_v2.preprocess_input on a TensorFlow dataset


I'm struggling with TensorFlow datasets again. As before, I'm loading my images via

data = keras.preprocessing.image_dataset_from_directory(
  './data',
  labels='inferred',
  label_mode='binary',
  validation_split=0.2,
  subset="training",
  seed=123,  # a seed is required when validation_split is combined with shuffling (the default)
  image_size=(img_height, img_width),
  batch_size=sz_batch,
  crop_to_aspect_ratio=True
)

I want to use this dataset with the pre-trained MobileNetV2:

model = keras.applications.mobilenet_v2.MobileNetV2(input_shape=(img_height, img_width, 3), weights='imagenet')

The documentation says that the input data must be scaled to the range [-1, 1], and the preprocess_input function is provided for this. When I apply this function to my dataset

scaled_data = tf.keras.applications.mobilenet_v2.preprocess_input(data)

I get the error: TypeError: unsupported operand type(s) for /=: 'BatchDataset' and 'float'
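A minimal offline sketch of why this fails, assuming a tiny synthetic dataset (the `images` and `labels` tensors below are made up and only stand in for the directory data). For MobileNetV2, `preprocess_input` computes `x / 127.5 - 1` by doing arithmetic directly on whatever it receives. A tensor supports that arithmetic; a `tf.data.Dataset` object does not. In the TF 2.x Keras from the question this surfaces as the `/=` TypeError; other Keras versions may raise a different exception type, so the sketch catches both `TypeError` and `AttributeError`.

```python
import tensorflow as tf

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

# Hypothetical stand-in for the question's dataset: four 2x2 RGB "images"
# (two black, two white) with binary labels, in a single batch of four.
images = tf.concat([tf.zeros([2, 2, 2, 3]), tf.fill([2, 2, 2, 3], 255.0)], axis=0)
labels = tf.constant([[0.0], [0.0], [1.0], [1.0]])
data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)


def try_on_dataset(ds):
    """Call preprocess_input on a Dataset object and return the exception it raises."""
    try:
        preprocess_input(ds)
    except (TypeError, AttributeError) as err:
        return err
    return None


error = try_on_dataset(data)
print(type(error).__name__, error)

# The same function works on a tensor, e.g. one batch pulled out of the dataset:
# 0 / 127.5 - 1 == -1.0 and 255 / 127.5 - 1 == 1.0.
batch_images, batch_labels = next(iter(data))
scaled = preprocess_input(batch_images)
print(float(tf.reduce_min(scaled)), float(tf.reduce_max(scaled)))  # -1.0 1.0
```

In other words, the function has to be applied to the tensors inside the dataset, not to the dataset object itself.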

So how do I apply this function correctly to a TensorFlow dataset?


Solution

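  • Try applying it to each batch with tf.data.Dataset.map; preprocess_input operates on tensors (or NumPy arrays), not on a Dataset object: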

    import tensorflow as tf
    import pathlib
    import matplotlib.pyplot as plt
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 32
    
    train_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(180, 180),
      batch_size=batch_size)
    
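    # Apply preprocess_input to the image batch only; labels pass through unchanged.
    # For MobileNetV2 it rescales pixel values from [0, 255] to [-1, 1].
    def preprocess(images, labels):
      return tf.keras.applications.mobilenet_v2.preprocess_input(images), labels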
    
    train_ds = train_ds.map(preprocess)
    
    images, _ = next(iter(train_ds.take(1)))
    image = images[0]
    # imshow clips float RGB values to [0, 1], so channel values below 0 display as 0.
    plt.imshow(image.numpy())
    

    Before preprocessing the images:

    [image]

    After preprocessing the images with tf.keras.applications.mobilenet_v2.preprocess_input only:

    [image]

    After preprocessing the images with tf.keras.layers.Rescaling(1./255) only:

    [image]
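If the goal is the same [-1, 1] range that `preprocess_input` produces, a `Rescaling` layer can also do it, provided it shifts the values as well: `Rescaling` computes `inputs * scale + offset`, so `scale=1./127.5, offset=-1` matches MobileNetV2's expected input range, whereas `Rescaling(1./255)` alone yields [0, 1]. A small sketch comparing the three on a hypothetical 1x1 RGB image with channel values 0, 127.5 and 255:

```python
import numpy as np
import tensorflow as tf

# Rescaling computes inputs * scale + offset. With scale=1/127.5 and offset=-1
# it maps [0, 255] to [-1, 1], the range MobileNetV2's preprocess_input yields.
to_minus_one_one = tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)
# Rescaling(1./255), as used for the last screenshot, maps [0, 255] to [0, 1].
to_zero_one = tf.keras.layers.Rescaling(scale=1.0 / 255)

# Hypothetical batch holding one 1x1 RGB image with channel values 0, 127.5, 255.
sample = tf.constant([[[[0.0, 127.5, 255.0]]]])  # shape (1, 1, 1, 3)

via_layer = np.asarray(to_minus_one_one(sample)).ravel()
via_function = np.asarray(
    tf.keras.applications.mobilenet_v2.preprocess_input(sample)
).ravel()
via_255 = np.asarray(to_zero_one(sample)).ravel()

print("Rescaling(1/127.5, -1):", via_layer)     # approximately [-1, 0, 1]
print("preprocess_input:      ", via_function)  # [-1, 0, 1]
print("Rescaling(1/255):      ", via_255)       # approximately [0, 0.5, 1]
```

Because it is a layer, the `Rescaling(1./127.5, offset=-1)` variant could also sit at the start of a model, in front of the MobileNetV2 base, so the scaling is stored with the model instead of being applied in the input pipeline.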