
Calling Keras standard model preprocessing functions in TF Dataset pipeline


I am using some of the standard CNN models shipped with Keras, say VGG16, as a base for my own models. Until now I have called the respective preprocessing functions through the Keras image data generators, like so:

    ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)  # or any other standard model

Now I want to use a tf.data.Dataset instead, so that I can use its from_tensor_slices() method, which makes multi-GPU training easier. I came up with the following custom preprocessing function for this new pipeline:

    @tf.function
    def load_images(image_path, label):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = vgg16.preprocess_input(image)  # Is this call correct?
        image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
        return (image, label)

But I am not sure whether this is the correct order of function calls, and in particular where vgg16.preprocess_input(image) belongs in that order. Can I call this standard preprocessing function like this, or do I need to convert the image data before or after it?


Solution

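  • You could create a dataset with from_tensor_slices() from your paths and labels and then use map() to load and preprocess the images. The important detail is the dtype: decode_jpeg returns uint8, so cast the image to float32 before calling preprocess_input. Because preprocess_input only flips the channel order and subtracts a constant from each channel, calling it before or after the (default bilinear) resize gives practically the same result; what to avoid is applying it to the raw uint8 tensor, as in your version: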

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy
    from PIL import Image
    
    # Create random images
    for i in range(3):
      imarray = numpy.random.rand(100,100,3) * 255
      im = Image.fromarray(imarray.astype('uint8'))
      im.save('result_image{}.jpeg'.format(i))
    
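    def load_images(image_path, label):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)  # uint8 tensor, shape (H, W, 3)
        
        # Cast to float32 first: on tensors, preprocess_input adds the ImageNet means
        # in the input's dtype, so a uint8 image would give wrong values
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
        
        # preprocess_input --> converts the images from RGB to BGR, then zero-centers each color channel with respect to the ImageNet dataset, without scaling
        image = tf.keras.applications.vgg16.preprocess_input(image)
        # Do not divide by 255 here: the pretrained VGG16 weights expect this unscaled, mean-centered input
        return image, label
    
    IMG_SIZE = 64
    paths = ['result_image0.jpeg', 'result_image1.jpeg', 'result_image2.jpeg']
    labels = [0, 1, 1]
    
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    ds = dataset.map(load_images)
    
    image, _ = next(iter(ds.take(1)))
    # For display only: undo the BGR flip and the mean subtraction, then scale to [0, 1]
    rgb = image[..., ::-1] + [123.68, 116.779, 103.939]
    plt.imshow(tf.clip_by_value(rgb / 255.0, 0.0, 1.0))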
    


    Alternatively, you can make tf.keras.applications.vgg16.preprocess_input part of the model itself, so the dataset pipeline only has to decode, cast and resize the images. For example:

    preprocess = tf.keras.applications.vgg16.preprocess_input
    
    # The model takes float32 RGB images and applies the VGG16 preprocessing in a Lambda layer
    some_input = tf.keras.layers.Input((256, 256, 3))
    some_output = tf.keras.layers.Lambda(preprocess)(some_input)
    model = tf.keras.Model(some_input, some_output)
    
    # Quick shape check with random data
    model(tf.random.normal((2, 256, 256, 3)))
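To check the ordering question without TensorFlow, here is a minimal NumPy sketch. It re-implements what vgg16.preprocess_input is documented to do ("caffe" mode: RGB to BGR, subtract the ImageNet channel means 103.939, 116.779, 123.68 in BGR order, no scaling), plus the inverse used for plotting. The helper names `vgg16_like_preprocess`, `undo_vgg16_preprocess` and `downsample_2x` are hypothetical, and using 2x2 average pooling as a stand-in for tf.image.resize is an assumption: like bilinear interpolation, every output pixel is a weighted average whose weights sum to 1, so a constant per-channel offset passes through unchanged.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe" preprocessing mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def vgg16_like_preprocess(rgb):
    """Hypothetical NumPy stand-in for vgg16.preprocess_input on an HxWx3 RGB image."""
    # Explicit float cast; NumPy would also promote uint8 here, TF's tensor path does not
    x = np.asarray(rgb, dtype=np.float32)
    bgr = x[..., ::-1]                  # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR      # zero-center each channel, no scaling


def undo_vgg16_preprocess(bgr_centered):
    """Inverse transform for plotting: back to RGB floats clipped to [0, 1]."""
    bgr = bgr_centered + IMAGENET_MEAN_BGR
    rgb = bgr[..., ::-1]                # BGR -> RGB
    return np.clip(rgb / 255.0, 0.0, 1.0)


def downsample_2x(img):
    """Hypothetical resize stand-in: 2x2 average pooling (weights sum to 1)."""
    h, w, c = img.shape
    return img.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(100, 100, 3), dtype=np.uint8)

# Order 1: resize, then preprocess.  Order 2: preprocess, then resize.
a = vgg16_like_preprocess(downsample_2x(image.astype(np.float32)))
b = downsample_2x(vgg16_like_preprocess(image))
print("max difference between the two orders:", np.abs(a - b).max())

# Round trip used for display: preprocess, then undo
restored = undo_vgg16_preprocess(vgg16_like_preprocess(image))
print("max round-trip error (0-255 scale):", np.abs(restored * 255.0 - image).max())
```

Both printed differences come only from float32 rounding, which is why the position of preprocess_input relative to the resize does not matter once the image is a float tensor.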