I'm trying to train a UNet, but for some reason I get the following error:
Traceback (most recent call last):
File "<ipython-input-54-b56497e81356>", line 1, in <module>
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=opt.learning_rate), loss=dice_coef_loss, metrics=[dice_coef])
File "C:\Users\...\Anaconda3\envs\gpuconda\lib\site-packages\tensorflow\python\keras\engine\training.py", line 537, in compile
with self.distribute_strategy.scope():
File "C:\Users\...\Anaconda3\envs\gpuconda\lib\site-packages\tensorflow\python\keras\engine\training.py", line 682, in distribute_strategy
return self._distribution_strategy or ds_context.get_strategy()
AttributeError: 'UNet' object has no attribute '_distribution_strategy'
I already looked for multiple answers of which one was replacing keras
into tf.keras
, but the error still arises. Another anwer was that it had to do with TensorBoard, so I removed the Tensorboard callback, but that also didn't fix it.
This is my model:
import tensorflow as tf
class UNet(tf.keras.Model):
def __init__(self, img_shape=(256,256,256), num_class=1):
print ('build UNet ...')
self.img_shape = img_shape+(1,)
self.num_class = num_class
def get_crop_shape(self, target, refer):
# depth, the 4th dimension
cd = (target.get_shape()[3] - refer.get_shape()[3])
assert (cd >= 0)
if cd % 2 != 0:
cd1, cd2 = int(cd//2), int(cd//2) + 1
cd1, cd2 = int(cd//2), int(cd//2)
# width, the 3rd dimension
cw = (target.get_shape()[2] - refer.get_shape()[2])
assert (cw >= 0)
if cw % 2 != 0:
cw1, cw2 = int(cw//2), int(cw//2) + 1
cw1, cw2 = int(cw//2), int(cw//2)
# height, the 2nd dimension
ch = (target.get_shape()[1] - refer.get_shape()[1])
assert (ch >= 0)
if ch % 2 != 0:
ch1, ch2 = int(ch//2), int(ch//2) + 1
ch1, ch2 = int(ch//2), int(ch//2)
return (ch1, ch2), (cw1, cw2), (cd1, cd2)
def __call__(self, inputs):
concat_axis = 4
conv1 = tf.keras.layers.Conv3D(8, (3, 3, 3), activation='relu', padding='same', name='conv1_1')(inputs)
conv1 = tf.keras.layers.Conv3D(8, (3, 3, 3), activation='relu', padding='same')(conv1)
pool1 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)
conv2 = tf.keras.layers.Conv3D(16, (3, 3, 3), activation='relu', padding='same')(pool1)
conv2 = tf.keras.layers.Conv3D(16, (3, 3, 3), activation='relu', padding='same')(conv2)
up_conv1 = tf.keras.layers.UpSampling3D(size=(2, 2, 2))(conv2)
ch, cw, cd = self.get_crop_shape(conv1, up_conv1)
crop_conv1 = tf.keras.layers.Cropping3D(cropping=(ch,cw,cd))(conv1)
up1 = tf.keras.layers.concatenate([up_conv1, crop_conv1], axis=concat_axis)
conv3 = tf.keras.layers.Conv3D(8, (3, 3, 3), activation='relu', padding='same')(up1)
conv3 = tf.keras.layers.Conv3D(8, (3, 3, 3), activation='relu', padding='same')(conv3)
ch, cw, cd = self.get_crop_shape(inputs, conv3)
conv3 = tf.keras.layers.ZeroPadding3D(padding=((ch[0], ch[1]), (cw[0], cw[1]), (cd[0], cd[1])))(conv3)
conv4 = tf.keras.layers.Conv3D(self.num_class, (1, 1, 1), activation="sigmoid")(conv3)
return conv4
And this is how I'm trying to train it:
# Initialize the model
model = UNet(img_shape=opt.img_shape, num_class=opt.num_class)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=opt.learning_rate), loss=dice_coef_loss, metrics=[dice_coef])
callbacks = [
tf.keras.callbacks.ModelCheckpoint(os.path.join(opt.checkpoint_path, "UNet.{epoch:02d}-{val_loss:.3f}.hdf5")),
tf.keras.callbacks.TensorBoard(log_dir=opt.log_dir, histogram_freq=1, write_graph=True, write_images=True, update_freq='epoch')
# Train the model, doing validation at the end of each epoch.
model.fit(trainDataset, epochs=opt.epoch, validation_data=testDataset, callbacks=callbacks)
I'm not using a distribution strategy. I use python version 3.7.9 and tensorflow version 2.3.0.
Can someone help me with this?
I found the answer, fortunately.
I forgot to call the superclass of the tensorflow keras model; I just had to add super(MyModel, self).__init__()
to my __init__()
So my model looks like this now:
class UNet(tf.keras.Model):
def __init__(self, img_shape=(256,256,256), num_class=1):
super(UNet, self).__init__()
print ('build UNet ...')
self.img_shape = img_shape
self.num_class = num_class