Even though the shapes of both inputs to the concatenation layer are the same (as I've printed), I get an error saying the shapes are different.
# Question's version of the model. This is an excerpt: every `...` line marks
# code the author elided, and the paste has lost its indentation.
class UNet(keras.Model):
def __init__(self, shape=(572, 572, 1), **kwargs):
self.concat = keras.layers.Concatenate(axis=-1) # concatenates along the last axis (depth)
...
class CONV2_BLOCK(keras.layers.Layer):
...
# Wrapper layer around a single Conv2DTranspose (kernel_size=2, strides=2).
class CONV_T(keras.layers.Layer):
def __init__(self, filters, **kwargs):
super().__init__(**kwargs)
self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)
def call(self, inputs):
outputs = self.conv_t(inputs)
return outputs
# Wrapper layer around a single Cropping2D; `cropping` is passed through unchanged.
class CROP(keras.layers.Layer):
def __init__(self, cropping, **kwargs):
super().__init__(**kwargs)
self.cropping = cropping
self.crop = keras.layers.Cropping2D(cropping=self.cropping)
def call(self, inputs):
outputs = self.crop(inputs)
return outputs
# Forward pass. NOTE(review): every `self.CONV2_BLOCK(...)`, `self.CONV_T(...)`
# and `self.CROP(...)` expression below constructs a new layer object each time
# this method runs.
def call(self, inputs):
# self.conv_arr = [64, 128, 256, 512, 1024]
# self.crop_arr = [4, 17, 40, 88] down to up
x1 = self.CONV2_BLOCK(filters=64)(inputs)
print(x1.shape)
x = self.maxpool(x1)
print(x.shape)
...
x = self.CONV2_BLOCK(filters=1024)(x)
print(x.shape)
# The two prints build their own CONV_T / CROP instances, separate from the
# ones used in the concat on the next line; they only report output shapes.
print(f"convt shape{self.CONV_T(filters=512)(x).shape}")
print(f"crop shape{self.CROP(cropping=4)(x4).shape}")
x = self.concat([self.CONV_T(filters=512)(x), self.CROP(cropping=4)(x4)])
x = self.CONV2_BLOCK(filters=512)(x)
# NOTE(review): the concat named in the traceback quoted after this snippet
# (filters=216, cropping=17 on x3) is not shown here; presumably it is part
# of the elided code below.
...
x = self.concat([self.CONV_T(filters=64)(x), self.CROP(cropping=88)(x1)])
x = self.CONV2_BLOCK(filters=64)(x)
outputs = self.conv_sz1(x)
return outputs
Output of the above code:
[conv_t shape(2, 56, 56, 512), crop shape(2, 56, 56, 512)] # printed
#error
--->83 x = self.concat([self.CONV_T(filters=216)(x), self.CROP(cropping=17)(x3)])
84 x = self.CONV2_BLOCK(filters=256)(x)
Dimension 1 in both shapes must be equal: shape[0] = [2,104,104,216] vs. shape[1] = [2,102,102,256] [Op:ConcatV2] name: concat
The following code works for me:
class UNet(keras.Model):
    """U-Net built from unpadded 3x3 convolutions and cropped skip connections.

    The encoder applies four ``CONV2_BLOCK`` stages (64, 128, 256, 512
    filters), each followed by 2x2 max pooling, then a 1024-filter bottleneck.
    The decoder mirrors it: each stage upsamples with ``CONV_T``, concatenates
    the centre-cropped encoder output of the same depth along the channel
    axis, and applies another ``CONV2_BLOCK``. A final 1x1 convolution maps
    to 2 output channels.

    All sub-layers are created once in ``__init__``. Creating them inside
    ``call`` would build new layers with new weights on every forward pass,
    so nothing learned would ever be reused.

    Args:
        shape: Nominal input shape ``(height, width, channels)``; defaults to
            ``(572, 572, 1)``. It is stored as ``self.shape`` and is not used
            to configure any layer.
        **kwargs: Forwarded to ``keras.Model``.

    Note:
        The skip-connection crop sizes (4, 16, 40, 88) are hard-coded for a
        572x572 input, for which the output is ``(batch, 388, 388, 2)``.
        Other input sizes will generally fail at the concatenation step.
    """

    # Filters per encoder stage; the decoder uses the same values in reverse.
    _ENCODER_FILTERS = (64, 128, 256, 512)
    _BOTTLENECK_FILTERS = 1024
    # Pixels cropped from every side of the skip tensor, deepest stage first,
    # so it matches the upsampled tensor: 64->56, 136->104, 280->200, 568->392.
    _SKIP_CROPS = (4, 16, 40, 88)

    def __init__(self, shape=(572, 572, 1), **kwargs):
        super().__init__(**kwargs)
        self.shape = shape

        # Stateless layers can safely be shared between stages.
        self.maxpool = keras.layers.MaxPool2D(pool_size=2, strides=2)
        self.concat = keras.layers.Concatenate(axis=-1)  # concat along depth
        self.conv_sz1 = keras.layers.Conv2D(filters=2, kernel_size=1, padding="same")

        # Layers with weights: one instance per stage, built exactly once.
        self.encoder_blocks = [
            self.CONV2_BLOCK(filters=f) for f in self._ENCODER_FILTERS
        ]
        self.bottleneck = self.CONV2_BLOCK(filters=self._BOTTLENECK_FILTERS)

        decoder_filters = tuple(reversed(self._ENCODER_FILTERS))
        self.up_convs = [self.CONV_T(filters=f) for f in decoder_filters]
        self.skip_crops = [self.CROP(cropping=c) for c in self._SKIP_CROPS]
        self.decoder_blocks = [
            self.CONV2_BLOCK(filters=f) for f in decoder_filters
        ]

    class CONV2_BLOCK(keras.layers.Layer):
        """Two unpadded 3x3 convolutions, each followed by batch norm and ReLU.

        With the default "valid" padding each convolution trims one pixel
        from every border, so the block shrinks height and width by 4.

        Args:
            filters: Number of output channels of both convolutions.
        """

        def __init__(self, filters, **kwargs):
            super().__init__(**kwargs)
            self.filters = filters
            # Bias is omitted because batch normalisation follows each conv.
            self.conv1 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)
            self.conv2 = keras.layers.Conv2D(filters=self.filters, kernel_size=3, use_bias=False)
            # One BatchNormalization per convolution: a shared instance would
            # mix the statistics and scale/offset of two different activations.
            self.batchnorm1 = keras.layers.BatchNormalization()
            self.batchnorm2 = keras.layers.BatchNormalization()
            self.relu = keras.layers.Activation(keras.activations.relu)

        def call(self, inputs):
            x = self.relu(self.batchnorm1(self.conv1(inputs)))
            return self.relu(self.batchnorm2(self.conv2(x)))

    class CONV_T(keras.layers.Layer):
        """2x2 transposed convolution with stride 2 (doubles height and width).

        Args:
            filters: Number of output channels.
        """

        def __init__(self, filters, **kwargs):
            super().__init__(**kwargs)
            self.conv_t = keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2)

        def call(self, inputs):
            return self.conv_t(inputs)

    class CROP(keras.layers.Layer):
        """Crop the spatial borders of a feature map.

        Args:
            cropping: Passed unchanged to ``keras.layers.Cropping2D``; an int
                removes that many pixels from every side.
        """

        def __init__(self, cropping, **kwargs):
            super().__init__(**kwargs)
            self.crop = keras.layers.Cropping2D(cropping=cropping)

        def call(self, inputs):
            return self.crop(inputs)

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Batch of images; the crop sizes assume 572x572 spatial size.

        Returns:
            The output of the final 1x1 convolution (2 channels).
        """
        # Encoder: keep each stage's pre-pooling output for the skip links.
        skips = []
        x = inputs
        for block in self.encoder_blocks:
            x = block(x)
            skips.append(x)
            x = self.maxpool(x)

        x = self.bottleneck(x)

        # Decoder: the deepest skip is consumed first, hence reversed(skips).
        stages = zip(self.up_convs, self.skip_crops, self.decoder_blocks, reversed(skips))
        for up_conv, crop, block, skip in stages:
            x = self.concat([up_conv(x), crop(skip)])
            x = block(x)

        return self.conv_sz1(x)
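What I've changed (both edits are in the second decoder concatenation, the line named in the error):
1. self.CONV_T(filters=216)(x)
to self.CONV_T(filters=256)(x)
(216 was a typo for 256, the filter count of that stage. It is not what triggers the concat error, which is about the spatial dimensions.)
2. self.CROP(cropping=17)(x3)
to self.CROP(cropping=16)(x3)
(x3 is 136x136 and the upsampled tensor is 104x104, so (136 - 104) / 2 = 16 pixels must be cropped from each side. Cropping 17 produced the 102x102 tensor reported in the error.)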