I am trying to implement a discriminative loss function for instance segmentation of images based on this paper: https://arxiv.org/pdf/1708.02551.pdf (This link is just for the readers' reference; I don't expect anyone to read it to help me out!)
My problem: once I move from a simple loss function to a more complicated one (like the one in the code snippet below), the loss drops to zero after the first epoch. I checked the weights, and almost all of them hover around -300. They are not exactly identical, but very close to each other (differing only in the decimal places).
Relevant code that implements the discriminative loss function:
```python
import tensorflow as tf  # TF 1.x API (tf.unsorted_segment_sum)

def regDLF(y_true, y_pred):
    global alpha
    global beta
    global gamma
    global delta_v
    global delta_d
    global image_height
    global image_width
    global nDim

    # Flatten labels to one instance id per pixel, embeddings to (pixels, nDim)
    y_true = tf.reshape(y_true, [image_height*image_width])
    X = tf.reshape(y_pred, [image_height*image_width, nDim])
    uniqueLabels, uniqueInd = tf.unique(y_true)
    numUnique = tf.size(uniqueLabels)

    # Cluster means: per-instance sum of embeddings / per-instance pixel count
    Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
    ones_Sigma = tf.ones((tf.shape(X)[0], 1))
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    mu = tf.divide(Sigma, ones_Sigma)

    # Regularization term: mean norm of the cluster means
    Lreg = tf.reduce_mean(tf.norm(mu, axis = 1))

    # Variance (pull) term: hinged distance of each pixel to its own cluster mean
    T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis = 1)
    T = tf.divide(T, Lreg)
    T = tf.subtract(T, delta_v)
    T = tf.clip_by_value(T, 0, T)  # effectively max(T, 0)
    T = tf.square(T)
    ones_Sigma = tf.ones_like(uniqueInd, dtype = tf.float32)
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
    clusterSigma = tf.divide(clusterSigma, ones_Sigma)
    Lvar = tf.reduce_mean(clusterSigma, axis = 0)

    # Distance (push) term over all numUnique**2 ordered pairs of means,
    # including each mean paired with itself
    mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
    mu_band_rep = tf.tile(mu, [1, numUnique])
    mu_band_rep = tf.reshape(mu_band_rep, (numUnique*numUnique, nDim))
    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
    mu_diff = tf.norm(mu_diff, axis = 1)
    mu_diff = tf.divide(mu_diff, Lreg)
    mu_diff = tf.subtract(2*delta_d, mu_diff)
    mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)  # effectively max(mu_diff, 0)
    mu_diff = tf.square(mu_diff)
    numUniqueF = tf.cast(numUnique, tf.float32)  # not used below
    Ldist = tf.reduce_mean(mu_diff)

    L = alpha * Lvar + beta * Ldist + gamma * Lreg
    return L
```
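For checking the TF version on a tiny input, here is a minimal NumPy sketch of the three terms following the paper's definitions (my reading of them; treat it as an assumption). The function name `discriminative_loss_np` and its default deltas and weights are hypothetical placeholders, not values from the code above. Unlike `regDLF`, it does not divide distances by `Lreg`, and it averages the distance term over pairs of *distinct* clusters only.

```python
import numpy as np


def discriminative_loss_np(labels, emb, delta_v=0.5, delta_d=1.5,
                           alpha=1.0, beta=1.0, gamma=0.001):
    """Hypothetical NumPy reference for one image (values only, no gradients).

    labels: instance id per pixel, any shape; flattened to (N,)
    emb:    pixel embeddings, reshaped to (N, nDim)
    Returns (total, (l_var, l_dist, l_reg)).
    """
    labels = np.asarray(labels).ravel()
    emb = np.asarray(emb, dtype=float).reshape(labels.size, -1)
    ids, inv = np.unique(labels, return_inverse=True)
    inv = inv.ravel()
    C = ids.size

    # Cluster means mu_c: per-cluster sum of embeddings / pixel count
    counts = np.bincount(inv, minlength=C).astype(float)
    mu = np.zeros((C, emb.shape[1]))
    np.add.at(mu, inv, emb)
    mu /= counts[:, None]

    # L_var: [||mu_c - x_i|| - delta_v]_+^2, averaged within each cluster, then over clusters
    d = np.linalg.norm(emb - mu[inv], axis=1)
    hinge_v = np.maximum(d - delta_v, 0.0) ** 2
    l_var = (np.bincount(inv, weights=hinge_v, minlength=C) / counts).mean()

    # L_dist: [2*delta_d - ||mu_A - mu_B||]_+^2 over ordered pairs with A != B only
    if C > 1:
        dist = np.linalg.norm(mu[:, None, :] - mu[None, :, :], axis=2)  # (C, C)
        off_diag = ~np.eye(C, dtype=bool)
        l_dist = (np.maximum(2 * delta_d - dist[off_diag], 0.0) ** 2).mean()
    else:
        l_dist = 0.0

    # L_reg: mean norm of the cluster means
    l_reg = np.linalg.norm(mu, axis=1).mean()

    total = alpha * l_var + beta * l_dist + gamma * l_reg
    return total, (l_var, l_dist, l_reg)
```

The self-pair exclusion matters for more than the value: in `regDLF`, each of the `numUnique` self-pairs contributes a constant `(2*delta_d)**2` before averaging, and those pairs are exactly where `tf.norm` receives a zero vector.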
Question: I know it's hard to understand what the code does without reading the paper, but I have a couple questions:
Is there something glaringly wrong with the loss function defined above?
Does anyone have a general idea why the loss could go to zero after the first epoch?
Thank you very much for your time and help!
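I think your problem comes from tf.norm, which is not numerically safe: when its input is an all-zero vector, its gradient is 0/0 and evaluates to NaN. In your loss this happens at least on the diagonal of mu_diff, where each cluster mean is subtracted from itself, and for any pixel whose embedding exactly equals its cluster mean. It would be better to replace tf.norm with this custom function, which adds a small epsilon under the square root: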
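```python
import tensorflow as tf

def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
    # sqrt(sum(x^2) + epsilon) has a finite gradient everywhere, unlike tf.norm at the zero vector
    # No keepdims, so the output has the same shape as tf.norm(inputs, axis=axis)
    squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis)
    safe_norm = tf.sqrt(squared_norm + epsilon)
    return tf.identity(safe_norm, name=name)
```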
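To see why the epsilon matters, the NumPy sketch below computes the Euclidean norm and its analytic gradient x / ||x||, with and without epsilon. It only mirrors the math; TensorFlow's autodiff through `tf.sqrt` gives the same NaN at the zero vector (an infinite derivative of the square root multiplied by a zero). The helper names `norm_and_grad` and `safe_norm_and_grad` are hypothetical.

```python
import numpy as np


def norm_and_grad(x):
    """Plain Euclidean norm and its analytic gradient x / ||x||."""
    n = np.sqrt(np.sum(x ** 2))
    with np.errstate(invalid="ignore", divide="ignore"):
        grad = x / n  # 0/0 -> nan when x is the zero vector
    return n, grad


def safe_norm_and_grad(x, epsilon=1e-7):
    """sqrt(||x||^2 + epsilon) and its gradient x / sqrt(||x||^2 + epsilon)."""
    n = np.sqrt(np.sum(x ** 2) + epsilon)
    return n, x / n  # finite everywhere, exactly 0 at the zero vector


print(norm_and_grad(np.zeros(3)))       # gradient is all nan
print(safe_norm_and_grad(np.zeros(3)))  # gradient is all zeros
```

The price of the epsilon is a small bias: the safe norm of a zero vector is sqrt(1e-7), about 3.2e-4, rather than 0, which is negligible next to margins like `delta_v` and `delta_d`.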