python, tensorflow, keras, tf.keras

UNet: InvalidArgumentError: Exception encountered when calling layer 'concatenate' (type Concatenate)


Even though the shapes of both inputs to the concatenation layer are the same (as I've printed), the error reports different shapes.

class UNet(keras.Model):
  def __init__(self, shape=(572, 572, 1), **kwargs):
    super().__init__(**kwargs)
    self.concat = keras.layers.Concatenate(axis=-1) # concats through depth
    ...

  class CONV2_BLOCK(keras.layers.Layer):
    ...      

  class CONV_T(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
      super().__init__(**kwargs)
      self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)

    def call(self, inputs):
      outputs = self.conv_t(inputs)
      return outputs

  class CROP(keras.layers.Layer):
    def __init__(self, cropping, **kwargs):
      super().__init__(**kwargs)
      self.cropping = cropping
      self.crop = keras.layers.Cropping2D(cropping=self.cropping)

    def call(self, inputs):
      outputs = self.crop(inputs)
      return outputs

  def call(self, inputs):
    # self.conv_arr = [64, 128, 256, 512, 1024]
    # self.crop_arr = [4, 17, 40, 88] down to up

    x1 = self.CONV2_BLOCK(filters=64)(inputs)
    print(x1.shape)
    x = self.maxpool(x1)
    print(x.shape)
    ...

    x = self.CONV2_BLOCK(filters=1024)(x)
    print(x.shape)

    print(f"convt shape{self.CONV_T(filters=512)(x).shape}")
    print(f"crop shape{self.CROP(cropping=4)(x4).shape}")
    x = self.concat([self.CONV_T(filters=512)(x), self.CROP(cropping=4)(x4)])
    x = self.CONV2_BLOCK(filters=512)(x)

    ...

    x = self.concat([self.CONV_T(filters=64)(x), self.CROP(cropping=88)(x1)])
    x = self.CONV2_BLOCK(filters=64)(x)

    outputs = self.conv_sz1(x)

    return outputs

Output of the above code:

conv_t shape: (2, 56, 56, 512)
crop shape: (2, 56, 56, 512)

Error (traceback excerpt):

---> 83 x = self.concat([self.CONV_T(filters=216)(x), self.CROP(cropping=17)(x3)])
     84 x = self.CONV2_BLOCK(filters=256)(x)

Dimension 1 in both shapes must be equal: shape[0] = [2,104,104,216] vs. shape[1] = [2,102,102,256] [Op:ConcatV2] name: concat
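
The mismatch is in the spatial axes, not the channel axis: with the default 572×572 input and valid-padding 3×3 convs, the skip tensor x3 is 136×136, so Cropping2D(cropping=17) trims it to 102×102 while the upsampling branch arrives at 104×104. (The 216-vs-256 filter counts only differ on the last axis, which Concatenate(axis=-1) does not require to match.) A minimal sketch of the spatial-size bookkeeping, assuming the block/pool structure shown above:

def block(s):                  # CONV2_BLOCK: two 3x3 valid convs, -2 each
    return s - 4

s, skips = 572, []
for _ in range(4):             # encoder: block, keep skip, 2x2 max-pool
    s = block(s)
    skips.append(s)            # 568, 280, 136, 64
    s //= 2
s = block(s)                   # bottleneck: 28

s = block(2 * s)               # up-step 1: ConvT doubles 28 -> 56, crop(4) on x4 -> 56, block -> 52
up = 2 * s                     # up-step 2: ConvT doubles 52 -> 104
print(up, skips[2] - 2 * 17)   # 104 102  -> the reported 104x104 vs 102x102
print(up, skips[2] - 2 * 16)   # 104 104  -> cropping=16 makes the shapes line up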


Solution

  • The following code works for me:

    from tensorflow import keras

    class UNet(keras.Model):
      """
      argument: input_shape=(572, 572, 1) => default
      """
      def __init__(self, shape=(572, 572, 1), **kwargs):
        super().__init__(**kwargs)
        self.shape = shape
        self.maxpool = keras.layers.MaxPool2D(pool_size=2, strides=2)
        self.concat = keras.layers.Concatenate(axis=-1) # concats through depth
        self.conv_sz1 = keras.layers.Conv2D(filters=2, kernel_size=1, padding="same")
    
      class CONV2_BLOCK(keras.layers.Layer):
        def __init__(self, filters, **kwargs):
          super().__init__(**kwargs)
          self.filters = filters
          self.conv1 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)
          self.batchnorm = keras.layers.BatchNormalization()  # note: this single instance is reused after both convs
          self.relu = keras.layers.Activation(keras.activations.relu)
          self.conv2 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)
    
        def call(self, inputs):
          x = self.conv1(inputs)
          x = self.batchnorm(x)
          x = self.relu(x)
          x = self.conv2(x)
          x = self.batchnorm(x)
          outputs = self.relu(x)
          return outputs
    
    
      class CONV_T(keras.layers.Layer):
        def __init__(self, filters, **kwargs):
          super().__init__(**kwargs)
          self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)
    
        def call(self, inputs):
          outputs = self.conv_t(inputs)
          return outputs
    
      class CROP(keras.layers.Layer):
        def __init__(self, cropping, **kwargs):
          super().__init__(**kwargs)
          self.crop = keras.layers.Cropping2D(cropping=cropping)
    
        def call(self, inputs):
          outputs = self.crop(inputs)
          return outputs
    
      def call(self, inputs):
        # self.conv_arr = [64, 128, 256, 512, 1024]
        # self.crop_arr = [4, 17, 40, 88] down to up
    
        x1 = self.CONV2_BLOCK(filters=64)(inputs)
        print(x1.shape)
        x = self.maxpool(x1)
        print(x.shape)
    
        x2 = self.CONV2_BLOCK(filters=128)(x)
        print(x2.shape)
        x = self.maxpool(x2)
        print(x.shape)
    
        x3 = self.CONV2_BLOCK(filters=256)(x)
        print(x3.shape)
        x = self.maxpool(x3)
        print(x.shape)
    
        x4 = self.CONV2_BLOCK(filters=512)(x)
        print(x4.shape)
        x = self.maxpool(x4)
        print(x.shape)
    
        x = self.CONV2_BLOCK(filters=1024)(x)
        print(x.shape)
    
        x = self.concat([self.CONV_T(filters=512)(x), self.CROP(cropping=4)(x4)])
        x = self.CONV2_BLOCK(filters=512)(x)
        
        # line edited
        x = self.concat([self.CONV_T(filters=256)(x), self.CROP(cropping=16)(x3)])
        x = self.CONV2_BLOCK(filters=256)(x)
    
        x = self.concat([self.CONV_T(filters=128)(x), self.CROP(cropping=40)(x2)])
        x = self.CONV2_BLOCK(filters=128)(x)
    
        x = self.concat([self.CONV_T(filters=64)(x), self.CROP(cropping=88)(x1)])
        x = self.CONV2_BLOCK(filters=64)(x)
    
        outputs = self.conv_sz1(x)
    
        return outputs
    

    What I've changed:

    • self.CONV_T(filters=216)(x) to self.CONV_T(filters=256)(x)
    • On the same line: self.CROP(cropping=17)(x3) to self.CROP(cropping=16)(x3), so both branches reach the concat at 104×104 (see the quick check below).
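
    For reference, a quick sanity check (a sketch, assuming the class above and an installed TensorFlow 2.x): build the model and push a dummy batch through it. With the 572×572×1 input and valid-padding convs, the final output should come out as (batch, 388, 388, 2).

    import tensorflow as tf

    model = UNet()
    dummy = tf.random.normal((2, 572, 572, 1))   # batch of 2, matching the printed shapes
    out = model(dummy)                           # also triggers the shape prints in call()
    print(out.shape)                             # expected: (2, 388, 388, 2)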