python keras deep-learning semantic-segmentation unet-neural-network

Facing an issue while training UNet for Image Segmentation


I am using the U-Net model from the keras_unet_collection Python library to perform self-supervised learning on images (spectrograms; dim=(376, 128)) to learn background/foreground masks. I am very new to self-supervised learning, but after reading the related research papers I came up with this approach:

  • Generate spectrograms from audio files (they are the ground truth)
  • Mask random sections of the spectrograms (these are the input images)

I am trying to train the U-Net model to reconstruct the distorted spectrogram and, in the process, learn to identify background noise for denoising the image/audio by classifying each pixel as 0 or 1 (where 0 should be background).

ground truth (y): ground truth spectrogram

input image (X): distorted spectrogram

But I have encountered an error when creating the model.

Given a PIL image of a spectrogram (img.size returns (376, 128)), this is how I've defined the model:

from keras_unet_collection import models

model = models.unet_2d(input_size=(img.size[0], img.size[1], 3),
                       filter_num=[64, 128, 256, 512, 1024],
                       stack_num_down=2, stack_num_up=1,
                       weights='imagenet', n_labels=1,
                       activation='GELU', output_activation='Softmax',
                       batch_norm=True, pool='max', unpool='nearest', name='unet')

I get this error:

ValueError: A Concatenate layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 46, 16, 512), (None, 47, 16, 512)]

This is the stack trace:

> keras_unet_collection/_model_unet_2d.py:288
X = unet_2d_base(IN, filter_num, stack_num_down=stack_num_down, stack_num_up=stack_num_up, 
                      activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, 
                      backbone=backbone, weights=weights, freeze_backbone=freeze_backbone, 
                      freeze_batch_norm=freeze_backbone, name=name)

> keras_unet_collection/_model_unet_2d.py:213
 X = UNET_right(X, [X_decode[i],], filter_num_decode[i], stack_num=stack_num_up, activation=activation, 
                        unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i))

> keras_unet_collection/_model_unet_2d.py:86
X = concatenate([X,]+X_list, axis=3, name=name+'_concat')

> keras/src/layers/merging/concatenate.py:172
return Concatenate(axis=axis, **kwargs)(inputs)

I'm really not able to figure this out. Any help/suggestion would be great!


Solution

  • You're likely running into the requirement that each spatial dimension of the input image be divisible by 2 ** N, where N is the number of downsampling (pooling) steps in your model.

    Each level of the contracting path (the left side) of the U-Net halves the number of pixels in each dimension while increasing the number of filters. On the expansion path, each level doubles the number of pixels.

    However, the number of pixels must always be an integer. If you have five pixels, divide by two and round down, you get two pixels. Double that again and you get four pixels, which no longer matches the five-pixel layer.

    Comparing this to your example: you have five filter levels, which, judging by the shapes in your error message, means four pooling steps (376 → 188 → 94 → 47 → 23, and upsampling 23 gives 46, not 47). So each input dimension must be divisible by 2 ** 4 = 16. 128 is, but 376 is not (376 / 16 = 23.5). Any multiple of 2 ** 5 = 32 also satisfies this.
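The halving and doubling described above can be traced with plain integer arithmetic. The following is only a sketch: the helper names are made up for illustration, and the number of pooling steps (4) is an assumption inferred from the shapes in the error message, not taken from the library's documentation.

```python
def trace_unet_sizes(size, num_poolings):
    """Trace one spatial dimension through a U-Net.

    Returns (encoder, decoder), where encoder[k] is the size after k
    poolings and decoder[k] is the size the decoder produces at the
    same level by repeated doubling from the bottleneck.
    """
    encoder = [size]
    for _ in range(num_poolings):
        encoder.append(encoder[-1] // 2)  # 2x2 pooling rounds down

    decoder = [encoder[-1]]  # start at the bottleneck
    for _ in range(num_poolings):
        decoder.append(decoder[-1] * 2)  # each upsampling doubles the size
    decoder.reverse()  # align so decoder[k] pairs with encoder[k]
    return encoder, decoder


def first_mismatch(size, num_poolings):
    """Return (level, encoder_size, decoder_size) for the first skip
    connection (walking up from the bottleneck) whose sizes differ,
    or None if every level matches."""
    encoder, decoder = trace_unet_sizes(size, num_poolings)
    for level in range(num_poolings - 1, -1, -1):
        if encoder[level] != decoder[level]:
            return level, encoder[level], decoder[level]
    return None


print(trace_unet_sizes(376, 4)[0])  # [376, 188, 94, 47, 23]
print(first_mismatch(376, 4))       # (3, 47, 46)
print(first_mismatch(128, 4))       # None
```

The 47 versus 46 pair for the 376-pixel axis is the same pair that appears in the error message, while the 128-pixel axis passes through cleanly.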

    This is a problem because of the concatenation step. When combining the input from the expansion path with the input from the contracting path, the shapes must match on every axis except the one you're concatenating on. So shapes of (None, 46, 16, 512) and (None, 46, 16, 512) would work. However, you have (None, 46, 16, 512) and (None, 47, 16, 512), which is a mismatch in dimension 1.
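To make the shape rule concrete, here is a small pure-Python check that mimics it. The function name is hypothetical and this is not how Keras implements the check; it only illustrates which axes are compared.

```python
def concat_mismatches(shapes, axis):
    """Return the axes, other than the concatenation axis, on which
    the given shapes disagree. An empty list means they are compatible."""
    rank = len(shapes[0])
    axis = axis % rank  # allow negative axes such as -1
    mismatched = []
    for dim in range(rank):
        if dim == axis:
            continue  # sizes may differ freely along the concatenation axis
        if len({shape[dim] for shape in shapes}) > 1:
            mismatched.append(dim)
    return mismatched


# Shapes from the error message, concatenated on the channel axis (3).
print(concat_mismatches([(None, 46, 16, 512), (None, 47, 16, 512)], axis=3))  # [1]
print(concat_mismatches([(None, 46, 16, 512), (None, 46, 16, 512)], axis=3))  # []
```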

    There are various ways to handle this.

    • Pad, crop or resize the original image to a multiple of 2 ** N. (This is likely the simplest - it doesn't require modifying this library.)
    • Within the U-Net, crop the larger feature map to the size of the smaller one before concatenation.
    • Within the U-Net, add up to 1 row or column of padding at each layer if needed.
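For the first option (padding the image before it reaches the model), the amount of padding can be computed as below. This is a minimal sketch with hypothetical helper names. The multiple passed in is an assumption that depends on the model depth; for this image, 376 rounds up to 384 whether the multiple is 16 or 32.

```python
def pad_amounts(size, multiple):
    """Return (before, after) padding that rounds `size` up to the
    next multiple of `multiple`, split as evenly as possible."""
    remainder = size % multiple
    total = 0 if remainder == 0 else multiple - remainder
    before = total // 2
    return before, total - before


def crop_slice(original_size, pad_before):
    """Slice that recovers the original extent from a padded axis,
    e.g. to crop the model's output back to the input size."""
    return slice(pad_before, pad_before + original_size)


# Spectrogram from the question: 376 x 128.
pad_rows = pad_amounts(376, 16)
pad_cols = pad_amounts(128, 16)
print(pad_rows, pad_cols)  # (4, 4) (0, 0)
```

If the image is a NumPy array of shape (376, 128, 3), these pairs can be passed as `pad_width` entries, for example `np.pad(x, (pad_rows, pad_cols, (0, 0)))`, and the prediction can be cropped back with `crop_slice(376, pad_rows[0])` along the first axis.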