I have this piece of code defined in the model_utils.py
local module.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm
import matplotlib.pyplot as plt
import tensorflow_io as tfio
import tensorflow as tf
from tensorflow import keras
keras.backend.clear_session() #uvolneni RAM
BACKBONE = 'efficientnetb0'
n_classes = 1
activation = 'sigmoid'
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
class DisplayCallback(keras.callbacks.Callback):
def __init__(self, dataset, epoch_interval=5):
self.dataset = dataset
self.epoch_interval = epoch_interval
def display(self, display_list, extra_title=''):
plt.figure(figsize=(15, 15))
title = ['Input Image', 'True Mask', 'Predicted Mask']
if len(display_list) > len(title):
title.append(extra_title)
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i+1)
plt.title(title[i])
plt.imshow(display_list[i], cmap='gray')
plt.axis('off')
plt.show()
def create_mask(self, pred_mask):
pred_mask = (pred_mask > 0.5).astype("int32")
return pred_mask[0]
def show_predictions(self, dataset, num=1):
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
self.display([image[0], mask[0], self.create_mask(pred_mask)])
def on_epoch_end(self, epoch, logs=None):
if epoch and epoch % self.epoch_interval == 0:
self.show_predictions(self.dataset)
print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
I imported it with from model_utils import *
into a jupyter notebook stored in the same directory. And used it as a callback for training of my model.
model.compile(optim, bce, metrics)
with tf.device('/gpu:0'):
model.fit(
train_dataset,
callbacks=[DisplayCallback(train_dataset)], #additionally cp_callback
epochs=400
)
Doing so I got this error as soon as the training approached the moment of using the DisplayCallback class.
NameError: name 'model' is not defined
Why is this? How can i resolve it? Please be kind and thanks a lot for any help in advance!
The problem is not with importing, but within DisplayCallback
. The class doesn't recognize what model
is.
Fixes:
Inside __init__
, call super().__init__()
. This method defines the self.model
attribute.
Inside your class, every time you refer to model
, it must be self.model
. I found one mistake in show_predictions
but I could've missed a few more.
Here's a reference of how to implement custom callbacks, scroll down to the part where self.model
is used.