tensorflow, deep-learning, keras, keras-layer

Putting multiple patches from a single image into a single mini-batch


For patch-wise training for image classification or segmentation, I need to put multiple patches taken from a single image into the same mini-batch during training. How can I do that in Keras? In other words, how can I ensure that all training patches in a mini-batch belong to the same training image?


Solution

  • I suggest implementing your own generator for this. It doesn't need to be complicated; your code could look something like this:

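    import numpy as np

    class PatchGenerator:
        def __init__(self, batch_size, X, y):
            self.batch_size = batch_size
            # self.X is a list of input images
            self.X = X
            # self.y is a list of target classes (one per image)
            self.y = y
            self.index = 0

        def __iter__(self):
            return self

        def __next__(self):
            # Get the next image and its target
            image = self.X[self.index]
            target = self.y[self.index]
            # Advance, wrapping around to the first image after the last one
            self.index = (self.index + 1) % len(self.X)

            patches = []
            targets = []
            for i in range(self.batch_size):
                # Generate a new random patch for the image
                patches.append(get_random_patch(image))  # Implement this yourself
                targets.append(target)
            # Keras expects each batch as an (inputs, targets) tuple of arrays
            return np.array(patches), np.array(targets)

        # Python 2 compatibility
        next = __next__

    # Create the new generator
    patch_generator = PatchGenerator(32, X, y)

    # Fit your model with the generator: one batch (step) per image per epoch.
    # fit_generator is the Keras 2 API; with TF >= 2.1, pass the generator to model.fit instead.
    model.fit_generator(patch_generator, steps_per_epoch=len(X))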
    

    The PatchGenerator class above ensures that every batch contains only patches from a single input image: it visits the images in order and builds one batch per image. Hopefully it gives you an idea of how to implement this.

    Take a look at the source code of keras.preprocessing for functions you can use to generate the patches (https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py).

    Also, read this if you need to learn more about Python generators: https://wiki.python.org/moin/Generators.
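The generator above leaves `get_random_patch` for you to implement. A minimal NumPy sketch, assuming images are arrays shaped H x W or H x W x C and a fixed patch size, could be a random crop like the one below. The `patch_size`, `rng` and `mask` parameters are hypothetical choices made for illustration. For segmentation, the optional `mask` is cropped at the same location so the patch and its label stay aligned.

```python
import numpy as np

def get_random_patch(image, patch_size=(32, 32), rng=None, mask=None):
    """Crop a random patch_size window from image (H x W or H x W x C).

    Hypothetical helper: if mask is given (segmentation), the same window
    is cropped from it and (patch, mask_patch) is returned.
    """
    rng = np.random.default_rng() if rng is None else rng
    ph, pw = patch_size
    h, w = image.shape[:2]
    if ph > h or pw > w:
        raise ValueError("patch_size is larger than the image")
    # Pick the top-left corner so the whole window fits inside the image
    top = rng.integers(0, h - ph + 1)
    left = rng.integers(0, w - pw + 1)
    patch = image[top:top + ph, left:left + pw]
    if mask is None:
        return patch
    return patch, mask[top:top + ph, left:left + pw]
```

Since every call picks a new random corner, calling it `batch_size` times on the same image gives different patches of that one image, which is what the batching scheme needs.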
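The same idea can also be written as a plain generator function using `yield`, which avoids the iterator boilerplate. This is a sketch, not the answer's exact code: `patch_batches` and its `patch_fn` argument (any function that returns one fixed-shape patch from an image) are hypothetical names chosen for illustration.

```python
import numpy as np

def patch_batches(X, y, batch_size, patch_fn):
    """Yield (patches, targets) batches forever; each batch comes from one image.

    patch_fn(image) must return a single patch with a fixed shape.
    """
    while True:  # Keras-style generators are expected to loop indefinitely
        for image, target in zip(X, y):
            # batch_size patches, all taken from the same image
            patches = np.stack([patch_fn(image) for _ in range(batch_size)])
            # Repeat the image's target once per patch (works for scalars or vectors)
            targets = np.repeat(np.asarray(target)[None], batch_size, axis=0)
            yield patches, targets
```

With tf.keras 2.x this could be used as `model.fit(patch_batches(X, y, 32, get_random_patch), steps_per_epoch=len(X))`, so that one epoch draws one batch per image. If you want the image order to change between epochs, shuffle the indices at the start of each pass through the `while` loop.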