python numpy deep-learning semantic-segmentation unet-neural-network

Why do I get this ValueError when training a U-Net?


I am a beginner in deep learning, and my graduation thesis is about semantic segmentation. However, I get a ValueError when using the "segmentation-models" library (docs of the library).

Part of my code is shown below. The error occurs at the line validation_data=(x_val, y_val).

import tensorflow as tf
import segmentation_models as sm
import glob
import cv2
import os
import numpy as np
from matplotlib import pyplot as plt



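# BACKBONE is assumed to be defined earlier (not shown)
model = sm.Unet(BACKBONE, encoder_weights='imagenet', encoder_freeze=True)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mse'])



# x_train, y_train, x_val and y_val are NumPy arrays prepared earlier (not shown)
history = model.fit(x_train,
                    y_train,
                    batch_size=8,
                    epochs=5000,
                    verbose=1,
                    validation_data=(x_val, y_val))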

The error message is: [screenshot, image not available]

I don't understand it (because of my limited English or lack of knowledge). x_val and y_val both have 3 channels, and both are NumPy arrays:

[screenshot, image not available]

Please help me!!! Thanks a lot!

I want to solve this error and train my model.


Solution

  • It seems that your labels have 3 channels, but the model's output has only 1 channel.
    The Unet constructor has a parameter called classes, which defines the number of output channels. You should probably set classes=3 when defining the model, for example sm.Unet(BACKBONE, classes=3, encoder_weights='imagenet', encoder_freeze=True).
    See documentation in https://segmentation-models.readthedocs.io/en/latest/api.html#unet
    And more specifically, note the following snippet:

    classes – a number of classes for output (output shape - (h, w, classes)).


    Clarification: The number of input channels is of no real consequence here. The channel count changes multiple times throughout the convolutional layers, the skip-connection concatenations, etc., so the only thing that matters in this context is the number of output channels versus the number of label channels. x_train could have 1 channel, 37 channels, or any other number of channels; the error you're getting comes only from comparing y_train (and y_val) to the logits/softmax/label estimations, or whatever you want to call them. The shape of that output is set, as mentioned above, by explicitly passing the classes parameter.
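
    To make the channel comparison concrete, here is a minimal sketch using only NumPy. It does not call segmentation-models at all: check_label_channels is a hypothetical helper, and the array sizes (4 images of 64x64) are made up for illustration. It only mimics the check that fails during fit when the last dimension of the model output differs from the last dimension of the labels.

```python
import numpy as np


def check_label_channels(output_shape, labels):
    """Hypothetical helper: compare the model's output channels with the label channels."""
    out_channels = output_shape[-1]
    label_channels = labels.shape[-1]
    if out_channels != label_channels:
        raise ValueError(
            f"model outputs {out_channels} channel(s) but labels have {label_channels}"
        )
    return True


# Dummy labels shaped like the question's masks: (batch, h, w, 3). Sizes are assumed.
y_val = np.zeros((4, 64, 64, 3), dtype=np.float32)

# With the default classes=1 the output shape is (None, h, w, 1): mismatch.
try:
    check_label_channels((None, 64, 64, 1), y_val)
    mismatch_detected = False
except ValueError as err:
    mismatch_detected = True
    print("mismatch:", err)

# With classes=3 the output shape is (None, h, w, 3): the shapes agree.
ok = check_label_channels((None, 64, 64, 3), y_val)
print("classes=3 matches labels:", ok)
```

    One caveat that goes beyond the original answer: if the 3 label channels are mutually exclusive one-hot classes, activation='softmax' with a categorical cross-entropy loss is the more usual pairing; if they are independent binary masks, the sigmoid output with binary_crossentropy from the question can stay as it is.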