So I am training a variation of a Unet style network in Tensorflow for a problem I am trying to solve. I have noticed an interesting pattern / error that I am unable to comprehend or fix.
As I have been training this network, on tensorboard the training loss is greater than validation loss, but the metric for validation is very low.(below)
But I have been looking at the output data from the network, and honestly, the output doesn't appear "half bad", at least not something that's a Dice of .25-.30
So when I externally validate the Dice by reloading the model and predicting on the validation set, I get a high dice score of > .90.
I have a feeling this is due to my loss and metrics utilized, but am unsure how to proceed. My loss metrics, and external validation metric code blocks are posted below.
Loss Class
class sce_dsc(losses.Loss):
def __init__(self, scale_sce=1.0, scale_dsc=1.0, sample_weight = None, epsilon=0.01, name=None):
super(sce_dsc, self).__init__()
self.sce = losses.SparseCategoricalCrossentropy(from_logits=False) #while the last layer activation is sigmoid, logits needs to be false
self.epsilon = epsilon
self.scale_a = scale_sce
self.scale_b = scale_dsc
self.cls = 1
self.weights = sample_weight
def dsc(self, y_true, y_pred, sample_weight = None):
true = tf.cast(y_true[..., 0] == self.cls, tf.int64)
pred = tf.nn.softmax(y_pred, axis=-1)[..., self.cls]
if self.weights is not None:
#true = true * (sample_weight[...])
true = true & (sample_weight[...] !=0)
#pred = pred * (sample_weight[...])
pred = pred & (sample_weight[...] !=0)
A = tf.math.reduce_sum(tf.cast(true, tf.float32) * tf.cast(pred,tf.float32)) * 2
B = tf.cast(tf.math.reduce_sum(true), tf.float32) + tf.cast(tf.math.reduce_sum(pred),tf.float32) + self.epsilon
return (1.0 - A/B)
def call(self, y_true, y_pred):
sce_loss = self.sce(y_true=y_true, y_pred=y_pred, sample_weight=self.weights) * self.scale_a
dsc_loss = self.dsc(y_true=y_true, y_pred=y_pred, sample_weight=self.weights) * self.scale_b
loss = tf.cast(sce_loss, tf.float32) + tf.cast(dsc_loss,tf.float32)
#self.add_loss(loss)
return loss```
Metric Class
class custom_dice(keras.metrics.Metric):
def __init__(self, name = "dsc", **kwargs):
super(custom_dice,self).__init__(**kwargs)
self.dice = self.add_weight(name = 'dice_coef', initializer = 'zeros')
def update_state(self, y_true,y_pred, sample_weight = None):
true = tf.cast(y_true[...,0] == 1, tf.int64)
pred = tf.math.argmax(y_pred == 1 , axis=-1)
if sample_weight is not None:
true = true * (sample_weight[...])
pred = pred * (sample_weight[...])
A = tf.math.count_nonzero(true & pred) * 2
B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)
value = tf.math.divide_no_nan(tf.cast(A, tf.float32),tf.cast(B, tf.float32))
self.dice.assign(value)
def result(self):
return self.dice
def reset_state(self):
self.dice.assign(0.0)
External Validation Dice
def dsc(y_true, y_pred, sample_weight=None, c = 1):
print(y_true.shape, y_pred.shape)
true = tf.cast(y_true[...,0] == 1, tf.int64)
pred = tf.math.argmax(y_pred== c , axis=-1)
print(true.shape,pred.shape)
if sample_weight is not None:
true = true * (sample_weight[...])
pred = pred * (sample_weight[...])
A = tf.math.count_nonzero(true & pred) * 2
B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)
return A / B
The metric above runs into an issue of calculating NaN's, or essentially 0 if the network does not predict anything on slices where there are none of the positive class. The rewritten code below fixes the issue:
def dice(self, y_true,y_pred, epsilon = p['epsilon']):
y_pred_arg = tf.math.argmax(y_pred, axis = -1)
y_true_f = tf.cast(K.flatten(y_true), tf.int64)
y_pred_f = tf.cast(K.flatten(y_pred_arg), tf.int64)
intersection = tf.cast(K.sum(y_true_f * y_pred_f), tf.float32)
dice = (2 * intersection + epsilon) / (tf.cast(K.sum(y_true_f), tf.float32) + tf.cast(K.sum(y_pred_f), tf.float32) + epsilon)
return tf.cast(dice, tf.float32)
the epsilon is a smoothing factor. This helps prevent a situation to divide by 0. I personally found epsilon = 1e-2 to have the best results on my current network, but this is definitely a hyper-parameter that should be optimized for training.