I have recently started learning about semantic segmentation, and I am trying to train a UNet for it. My inputs are 128x128x3 RGB images. My masks contain 4 classes (0, 1, 2, 3) and are one-hot encoded with shape 128x128x4.
def weighted_cce(y_true, y_pred):
    weights = []
    t_inf = tf.convert_to_tensor(1e9, dtype = 'float32')
    t_zero = tf.convert_to_tensor(0, dtype = 'int64')
    for i in range(0, 4):
        l = tf.argmax(y_true, axis = -1) == i
        n = tf.cast(tf.math.count_nonzero(l), 'float32') + K.epsilon()
        weights.append(n)
    weights = [batch_size/j for j in weights]
    y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
    # clip to prevent NaN's and Inf's
    y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
    # calc
    loss = y_true * K.log(y_pred) * weights
    loss = -K.sum(loss, -1)
    return loss
This is the loss function that I am using but it classifies every pixel as 2. What am I doing wrong?
You should compute the weights from your entire dataset (unless your batch size is large enough that the per-batch weights are reasonably stable).
If some class is underrepresented and the batch size is small, a batch may contain almost no pixels of that class, and its weight will be close to infinity.
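To make the instability concrete, here is a minimal NumPy sketch that compares weights computed over a whole dataset with weights computed per batch. The toy masks, the batch size of 2, and the helper name `inverse_frequency_weights` are all assumptions made for illustration. The epsilon mirrors the `K.epsilon()` trick from the question.

```python
import numpy as np

def inverse_frequency_weights(one_hot_masks, eps=1e-7):
    """Return total_pixels / pixels_per_class for masks of shape (N, H, W, C)."""
    counts = one_hot_masks.sum(axis=(0, 1, 2))        # shape (C,)
    total_pixels = np.prod(one_hot_masks.shape[:3])
    # eps avoids division by zero; a class that is absent gets a huge weight
    return total_pixels / (counts + eps)

# Hypothetical toy data: 8 masks of 16x16 pixels with 4 classes.
labels = np.zeros((8, 16, 16), dtype=int)
labels[:, :, 8:] = 1        # right half of every image is class 1
labels[:, 12:, :4] = 2      # a small block of class 2 in every image
labels[0, 0, 0] = 3         # class 3 appears in a single pixel of image 0 only
masks = np.eye(4)[labels]   # one-hot, shape (8, 16, 16, 4)

dataset_weights = inverse_frequency_weights(masks)
batch_weights = np.stack(
    [inverse_frequency_weights(masks[i:i + 2]) for i in range(0, 8, 2)]
)

print("dataset weights:         ", dataset_weights)
print("class 3 weight per batch:", batch_weights[:, 3])
```

Every batch that happens to contain no class 3 pixels gets a class 3 weight that is many orders of magnitude larger than the dataset-wide one, which is the instability described above.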
If your target data is a NumPy array:
shp = y_train.shape
totalPixels = shp[0] * shp[1] * shp[2]
weights = np.sum(y_train, axis=(0, 1, 2)) #final shape (4,)
weights = totalPixels/weights
If your data is in a Sequence generator:
totalPixels = 0
counts = np.zeros((4,))
for i in range(len(generator)):
    x, y = generator[i]
    shp = y.shape
    totalPixels += shp[0] * shp[1] * shp[2]
    counts = counts + np.sum(y, axis=(0,1,2))
weights = totalPixels / counts
If your data is in a yield generator (you must know how many batches you have in an epoch):
for i in range(batches_per_epoch):
    x, y = next(generator)
    #the rest is equal to the Sequence example above
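Spelled out in full, the yield-generator case might look like the sketch below. The function name `class_weights_from_generator` and the stand-in `toy_generator` are hypothetical. Substitute your own generator and your real `batches_per_epoch`.

```python
import numpy as np

def class_weights_from_generator(generator, batches_per_epoch, num_classes=4):
    """Accumulate per-class pixel counts over one epoch of (x, y) batches."""
    total_pixels = 0
    counts = np.zeros((num_classes,))
    for _ in range(batches_per_epoch):
        x, y = next(generator)                     # y: (batch, H, W, num_classes)
        total_pixels += y.shape[0] * y.shape[1] * y.shape[2]
        counts += np.sum(y, axis=(0, 1, 2))
    return total_pixels / counts

def toy_generator(batch_size=2):
    """Hypothetical stand-in for a real data generator (assumption)."""
    while True:
        x = np.zeros((batch_size, 128, 128, 3), dtype="float32")
        labels = np.zeros((batch_size, 128, 128), dtype=int)
        labels[:, :64, 64:] = 1     # each class covers one quadrant
        labels[:, 64:, :64] = 2
        labels[:, 64:, 64:] = 3
        yield x, np.eye(4)[labels]  # one-hot masks

weights = class_weights_from_generator(toy_generator(), batches_per_epoch=3)
print(weights)
```

Because each class covers exactly a quarter of the pixels in this toy data, all four weights come out equal. With real data the rare classes receive the larger weights.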
I don't know if newer versions of Keras are able to handle this, but you can try the simplest approach first: simply call fit or fit_generator with the class_weight argument:
model.fit(...., class_weight = {0: weights[0], 1: weights[1], 2: weights[2], 3: weights[3]})
Make a healthier loss function:
weights = weights.reshape((1,1,1,4))
kWeights = K.constant(weights)
def weighted_cce(y_true, y_pred):
    # pick each pixel's weight by its true class (y_true is one-hot);
    # weighting by y_pred would let the model lower the loss by predicting low-weight classes
    yWeights = kWeights * y_true #shape (batch, 128, 128, 4)
    yWeights = K.sum(yWeights, axis=-1) #shape (batch, 128, 128)
    loss = K.categorical_crossentropy(y_true, y_pred) #shape (batch, 128, 128)
    wLoss = yWeights * loss
    return K.sum(wLoss, axis=(1,2))
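To sanity-check a loss like this outside of Keras, a small NumPy reference can help. The sketch below is an illustration, not part of the Keras API: the name `weighted_cce_numpy`, the toy inputs, and the example weights are assumptions. In this variant each pixel's weight is selected by its true class.

```python
import numpy as np

def weighted_cce_numpy(y_true, y_pred, class_weights, eps=1e-7):
    """Class-weighted categorical cross-entropy, summed over each image.

    y_true, y_pred: arrays of shape (batch, H, W, C); y_true is one-hot.
    class_weights:  array of shape (C,).
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)                   # avoid log(0)
    pixel_weights = np.sum(class_weights * y_true, axis=-1)    # (batch, H, W)
    pixel_loss = -np.sum(y_true * np.log(y_pred), axis=-1)     # (batch, H, W)
    return np.sum(pixel_weights * pixel_loss, axis=(1, 2))     # (batch,)

# Toy example: one 1x2 "image" whose pixels belong to classes 0 and 3.
y_true = np.eye(4)[np.array([[[0, 3]]])]    # shape (1, 1, 2, 4)
y_pred = np.full((1, 1, 2, 4), 0.25)        # uniform predictions
class_weights = np.array([1.0, 1.0, 1.0, 10.0])

loss = weighted_cce_numpy(y_true, y_pred, class_weights)
print(loss)  # equals (1 + 10) * ln(4) for this input
```

With uniform predictions every pixel has the same unweighted loss, ln(4), so the result shows directly how much more a mistake on the rare class 3 costs than one on class 0.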