python tensorflow image-processing conv-neural-network

How can I add my custom image preprocessing method to models.Sequential() in TensorFlow?


I have implemented a method that performs image preprocessing tasks. It works, but only on one image at a time. Because it uses OpenCV to do the image processing, I don't know how to make it compatible with a tf dataset. How can I add it as a layer in models.Sequential() so that it meets the type requirements of Sequential()? This is the code where I want to add my custom function:

from tensorflow.keras import layers, models

# resize_and_rescale, data_augmentation, input_shape and n_classes are defined elsewhere
model = models.Sequential([
    resize_and_rescale,
    data_augmentation,
    custom_function_want_to_add_here(),  # placeholder: the OpenCV preprocessing should go here
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(n_classes, activation='softmax'),
])

Solution

  • Two options come to mind:

    1. put your custom function into a custom TensorFlow layer

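      • the convenience of this method is that all the logic is kept inside the model; however, the implementation could be challenging, since the layer logic has to work on TensorFlow tensors rather than NumPy arrays (OpenCV functions expect NumPy arrays, so you would have to wrap them, e.g. with tf.numpy_function, or rewrite them with TensorFlow ops)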
    2. put the image preprocessing outside of the model

      • since the preprocessing does not need any training, you can put it outside the model (so that the first layer of the model would be layers.Conv2D(32, ...))

    Since your custom logic requires calling a function from the OpenCV library (if I understood correctly), I would recommend the second approach. It will be easier to make the function work with a tf.data.Dataset than it would be to make it work inside the model.
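The second approach can be sketched roughly as follows. This is a minimal sketch, not the asker's code: `custom_preprocess` is a hypothetical per-image function standing in for the OpenCV logic (a simple min-max contrast stretch here, where the real function would call cv2), and `IMG_SIZE` (64) plus the random toy dataset are arbitrary assumptions. `tf.numpy_function` lets `Dataset.map` call the NumPy/OpenCV function on one image at a time, which matches how the existing function already works.

```python
import numpy as np
import tensorflow as tf

IMG_SIZE = 64  # assumption: images are already IMG_SIZE x IMG_SIZE x 3


def custom_preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the OpenCV code.

    It receives one HxWxC NumPy array, so cv2 calls (e.g. cv2.GaussianBlur)
    could be used here instead of this per-channel min-max stretch.
    """
    image = image.astype(np.float32)
    lo = image.min(axis=(0, 1), keepdims=True)
    hi = image.max(axis=(0, 1), keepdims=True)
    return (image - lo) / (hi - lo + 1e-7)


def tf_preprocess(image, label):
    # Run the plain Python/NumPy function on each dataset element
    out = tf.numpy_function(custom_preprocess, [image], tf.float32)
    # numpy_function loses static shape information; restore it for Conv2D
    out.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return out, label


# Toy dataset of random uint8 images, purely for illustration
images = np.random.randint(0, 256, size=(8, IMG_SIZE, IMG_SIZE, 3)).astype(np.uint8)
labels = np.arange(8)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Map *before* batching, because the function handles one image at a time
ds = ds.map(tf_preprocess, num_parallel_calls=tf.data.AUTOTUNE).batch(4)

for batch_images, batch_labels in ds.take(1):
    print(batch_images.shape)  # (4, 64, 64, 3)
```

The resulting `ds` can then be passed to `model.fit` for a model whose first layer is `layers.Conv2D(32, ...)`, with `custom_function_want_to_add_here()` removed from the `Sequential` list.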
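For comparison, the first approach (a custom layer) might look like the minimal sketch below. It assumes the same kind of hypothetical NumPy stand-in for the OpenCV code, plus 64x64x3 inputs and 3 classes (both arbitrary); the layer name `CustomPreprocessing` is made up for illustration. Inside a model the layer receives a whole batch as one tensor, so the wrapped function loops over the images itself.

```python
import numpy as np
import tensorflow as tf


def custom_preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical per-image stand-in for the OpenCV code (min-max stretch)."""
    image = image.astype(np.float32)
    lo = image.min(axis=(0, 1), keepdims=True)
    hi = image.max(axis=(0, 1), keepdims=True)
    return (image - lo) / (hi - lo + 1e-7)


def custom_preprocess_batch(images: np.ndarray) -> np.ndarray:
    # The layer sees a whole batch, so apply the single-image
    # function to each HxWxC array in turn
    return np.stack([custom_preprocess(img) for img in images]).astype(np.float32)


class CustomPreprocessing(tf.keras.layers.Layer):
    """Hypothetical layer wrapping the NumPy/OpenCV code with tf.numpy_function."""

    def call(self, inputs):
        outputs = tf.numpy_function(custom_preprocess_batch, [inputs], tf.float32)
        # numpy_function drops static shape information; restore it
        outputs.set_shape(inputs.shape)
        return outputs


# The custom layer takes the place of custom_function_want_to_add_here()
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),  # assumption: 64x64 RGB inputs
    CustomPreprocessing(),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(3, activation="softmax"),  # assumption: 3 classes
])

preds = model(np.random.rand(2, 64, 64, 3).astype(np.float32))
print(preds.shape)  # (2, 3)
```

Note that the Python body of a `tf.numpy_function` is not serialized into a SavedModel, so the model only works where that Python function is available, and the op must run in the same Python process. This is part of why keeping the preprocessing in the input pipeline is usually simpler.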