TensorFlow loss function zeroes out after first epoch


I am trying to implement a discriminative loss function for instance segmentation of images based on this paper: https://arxiv.org/pdf/1708.02551.pdf (This link is just for the readers' reference; I don't expect anyone to read it to help me out!)

My problem: Once I move from a simple loss function to a more complicated one (like you see in the attached code snippet), the loss function zeroes out after the first epoch. I checked the weights, and almost all of them seem to hover closely around -300. They are not exactly identical, but very close to each other (differing only in the decimal places).

Relevant code that implements the discriminative loss function:

def regDLF(y_true, y_pred):
    global alpha
    global beta
    global gamma
    global delta_v
    global delta_d
    global image_height
    global image_width
    global nDim

    y_true = tf.reshape(y_true, [image_height*image_width])

    X = tf.reshape(y_pred, [image_height*image_width, nDim])
    uniqueLabels, uniqueInd = tf.unique(y_true)

    numUnique = tf.size(uniqueLabels)

    Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
    ones_Sigma = tf.ones((tf.shape(X)[0], 1))
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    mu = tf.divide(Sigma, ones_Sigma)

    Lreg = tf.reduce_mean(tf.norm(mu, axis = 1))

    T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis = 1)
    T = tf.divide(T, Lreg)
    T = tf.subtract(T, delta_v)
    T = tf.clip_by_value(T, 0, T)
    T = tf.square(T)

    ones_Sigma = tf.ones_like(uniqueInd, dtype = tf.float32)
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
    clusterSigma = tf.divide(clusterSigma, ones_Sigma)

    Lvar = tf.reduce_mean(clusterSigma, axis = 0)

    mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
    mu_band_rep = tf.tile(mu, [1, numUnique])
    mu_band_rep = tf.reshape(mu_band_rep, (numUnique*numUnique, nDim))

    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
    mu_diff = tf.norm(mu_diff, axis = 1)
    mu_diff = tf.divide(mu_diff, Lreg)

    mu_diff = tf.subtract(2*delta_d, mu_diff)
    mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
    mu_diff = tf.square(mu_diff)

    numUniqueF = tf.cast(numUnique, tf.float32)
    Ldist = tf.reduce_mean(mu_diff)        

    L = alpha * Lvar + beta * Ldist + gamma * Lreg

    return L

Question: I know it's hard to understand what the code does without reading the paper, but I have a couple of questions:

  1. Is there something glaringly wrong with the loss function defined above?

  2. Does anyone have a general idea as to why the loss function could zero out after the first epoch?

Thank you very much for your time and help!


Solution

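  • I think the problem comes from tf.norm, which is not numerically safe: wherever the vector being normed is exactly zero, its gradient is NaN, and a NaN gradient corrupts the weights. That can happen in your code, for example in mu_diff, where every cluster mean is also subtracted from itself. It would be better to replace tf.norm with this custom function, which adds a small epsilon before taking the square root: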

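    def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
        # The reduced axis is kept with size 1 (older TF 1.x spells this keep_dims);
        # squeeze it if you need the shape tf.norm(x, axis=axis) returns.
        squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis, keepdims=True)
        # epsilon keeps the sqrt argument above zero, so the gradient stays finite.
        safe_norm = tf.sqrt(squared_norm + epsilon)
        return tf.identity(safe_norm, name=name)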
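
To see why a zero vector is the troublesome case, here is a minimal sketch that reproduces by hand, in NumPy, the chain rule an autodiff framework applies to sqrt(sum(x**2)). The helper name norm_grad is made up for this illustration and is not part of TensorFlow. It only mimics the arithmetic, under the assumption that the norm is computed as a square root of a sum of squares.

```python
import numpy as np

def norm_grad(x, epsilon=0.0):
    """Hand-written gradient of sqrt(sum(x**2) + epsilon) with respect to x.

    Hypothetical helper for illustration only: it follows the chain rule
    d sqrt(s)/ds * ds/dx, where s = sum(x**2) + epsilon.
    """
    squared_norm = np.sum(np.square(x))
    # Silence the divide-by-zero and invalid-value warnings so the
    # non-finite values show up in the result instead.
    with np.errstate(divide='ignore', invalid='ignore'):
        d_sqrt = 0.5 / np.sqrt(squared_norm + epsilon)  # inf when the argument is 0
        return d_sqrt * 2.0 * x                         # inf * 0 gives nan

zero_vec = np.zeros(3, dtype=np.float32)

plain_grad = norm_grad(zero_vec)                # plain norm: every entry is nan
safe_grad = norm_grad(zero_vec, epsilon=1e-7)   # safe norm: every entry is 0

print(plain_grad)
print(safe_grad)
```

With epsilon set to zero the square-root derivative is infinite and gets multiplied by a zero, which gives NaN. With a small epsilon the derivative is large but finite, so the product is an ordinary zero.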