Tags: python, image, tensorflow, keras

Multi-scale segmentation mask outputs in Keras with U-Net


The model takes a single image as input and produces outputs at four scales of the image, i.e., I, 1/2 I, 1/4 I and 1/8 I: Model(inputs=[inputs], outputs=[out6, out7, out8, out9])

I am not sure how to create the training dataset. Suppose the input data has shape (50, 192, 256, 3): there are 50 images, 192 and 256 are the spatial dimensions (height and width in Keras's default channels-last layout), and 3 is the number of channels. How do I create a y_train that has 4 components? I tried using zip and then converting the result to a numpy array, but that doesn't work.


Solution

  • If you specifically want the model to learn to generate multi-scale masks, you can downsample the full-resolution masks to produce the scaled targets for supervised training of the U-Net. Interpolation-based methods can resize an image automatically with little loss of information. I compare benchmarks for several such methods in a separate post.

    If you want to build [masks, masks_half, masks_quarter, masks_eighth] for your model.fit, i.e. the list of the original masks plus their rescaled versions, you may want to use a fast downsampling method (depending on the size of your dataset).

    Here I have used skimage.transform.pyramid_reduce to downsample a mask to a half, a quarter, and an eighth of its original scale. The function smooths the image and then downsamples it using spline interpolation, and its behaviour can be controlled via parameters. See the function's documentation for more details.

    import numpy as np
    from skimage.transform import pyramid_reduce
    
    # Placeholder data standing in for real masks: (batch, height, width, channels)
    masks = np.random.random((50, 192, 256, 3))
    
    # channel_axis=-1 marks the last axis as channels so it is not downsampled;
    # older scikit-image releases used multichannel=True instead.
    masks_half = np.stack([pyramid_reduce(m, 2, channel_axis=-1) for m in masks])
    masks_quarter = np.stack([pyramid_reduce(m, 4, channel_axis=-1) for m in masks])
    masks_eighth = np.stack([pyramid_reduce(m, 8, channel_axis=-1) for m in masks])
    
    print('Shape of original', masks.shape)
    print('Shape of half scaled', masks_half.shape)
    print('Shape of quarter scaled', masks_quarter.shape)
    print('Shape of eighth scaled', masks_eighth.shape)
    
    Shape of original (50, 192, 256, 3)
    Shape of half scaled (50, 96, 128, 3)
    Shape of quarter scaled (50, 48, 64, 3)
    Shape of eighth scaled (50, 24, 32, 3)
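
    To connect this back to the question of a y_train with 4 components: for a multi-output Keras model, the targets can be passed as a plain Python list with one array per output, in the same order as the outputs in Model(...). The sketch below builds such a list using only numpy. It swaps pyramid_reduce for simple block averaging so that it has no extra dependencies; block_mean_downsample is a hypothetical helper written for this example, and the assumption that out6 is the full-resolution output and out9 the 1/8-scale one should be checked against the actual model.

```python
import numpy as np

def block_mean_downsample(batch, factor):
    """Downsample a (N, H, W, C) batch by averaging factor x factor pixel blocks.

    Hypothetical helper for this sketch; not part of skimage or Keras.
    """
    n, h, w, c = batch.shape
    if h % factor or w % factor:
        raise ValueError("height and width must be divisible by factor")
    # Split each spatial axis into (blocks, factor), then average over the factor axes.
    blocks = batch.reshape(n, h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(2, 4))

# Placeholder data standing in for real masks.
masks = np.random.random((50, 192, 256, 3)).astype("float32")

# One target array per model output.
# Assumption: the model outputs are ordered full, 1/2, 1/4, 1/8 scale.
y_train = [masks] + [block_mean_downsample(masks, f) for f in (2, 4, 8)]

for target in y_train:
    print(target.shape)

# With a compiled Keras model and a matching x_train (both hypothetical here):
# model.fit(x_train, y_train, epochs=..., batch_size=...)
```

    The reason zip followed by a numpy conversion fails is that the four arrays have different spatial shapes, so they cannot be packed into one regular array; a list keeps them separate.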
    

    Testing on a single image/mask -

    import matplotlib.pyplot as plt
    from skimage.data import camera
    from skimage.transform import pyramid_reduce
    
    def plotit(img, h, q, e):
        # Show the original and its three downscaled versions side by side
        fig, axes = plt.subplots(1, 4, figsize=(10, 15))
        axes[0].imshow(img)
        axes[1].imshow(h)
        axes[2].imshow(q)
        axes[3].imshow(e)
        axes[0].set_title('Original')
        axes[1].set_title('Half')
        axes[2].set_title('Quarter')
        axes[3].set_title('Eighth')
        plt.show()
    
    img = camera()               # grayscale test image, shape (512, 512)
    h = pyramid_reduce(img, 2)   # half
    q = pyramid_reduce(img, 4)   # quarter
    e = pyramid_reduce(img, 8)   # eighth
    
    plotit(img, h, q, e)
    


    Notice how the x- and y-axis ranges in the plot shrink from one panel to the next as the scale decreases.
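
    One caveat with interpolation-based downsampling: the result contains blended, fractional values. That is fine for soft or probability-like masks, but if the masks hold discrete class ids, the blended values no longer correspond to any class. A minimal sketch of an alternative for that case, assuming integer label masks in (N, H, W, C) layout, is plain strided subsampling, which only ever copies existing pixel values. The subsample helper and the random labels array are made up for this illustration.

```python
import numpy as np

# Hypothetical integer label masks: 4 masks of 192 x 256 pixels, class ids 0..2.
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(4, 192, 256, 1))

def subsample(batch, factor):
    """Keep every factor-th pixel along height and width (no interpolation)."""
    return batch[:, ::factor, ::factor, :]

labels_half = subsample(labels, 2)
labels_quarter = subsample(labels, 4)
labels_eighth = subsample(labels, 8)

print(labels_eighth.shape)
print(np.unique(labels_eighth))  # only values that already occur in labels
```

    The trade-off is that strided sampling does no smoothing, so thin structures in a mask can vanish at coarse scales.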