python tensorflow keras loss-function cross-entropy

Output of custom loss in Keras


I know there are many questions about custom loss functions in Keras, but I've been unable to answer this one even after 3 hours of googling.

Here is a very simplified example of my problem. I realize this example is pointless on its own, but I provide it for simplicity; what I actually need to implement is more complicated.

import tensorflow as tf
from keras.backend import binary_crossentropy
from keras.backend import mean


def custom_loss(y_true, y_pred):
    # Indices of the samples labelled 0
    zeros = tf.zeros_like(y_true)
    index_of_zeros = tf.where(tf.equal(zeros, y_true))
    # Indices of the samples labelled 1
    ones = tf.ones_like(y_true)
    index_of_ones = tf.where(tf.equal(ones, y_true))

    # Intended: the predictions belonging to each class
    zero = tf.gather(y_pred, index_of_zeros)
    one = tf.gather(y_pred, index_of_ones)

    # Cross-entropy of each group against its constant target
    loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
    loss_1 = binary_crossentropy(tf.ones_like(one), one)

    return mean(tf.concat([loss_0, loss_1], axis=0))

I do not understand why training the network with the above loss function on a two-class dataset does not yield the same result as training with the built-in binary_crossentropy loss function. Thank you!

EDIT: I edited the code snippet to include the mean, as suggested in the comments below. However, I still get the same behavior.


Solution

  • I finally figured it out. The tf.where function behaves very differently when the shape is "unknown". To fix the snippet above, flatten both tensors to 1-D by inserting the following lines at the top of the function body, right after the def line:

    y_pred = tf.reshape(y_pred, [-1])
    y_true = tf.reshape(y_true, [-1])
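
One plausible reading of why the reshape helps can be checked on a toy batch. The sketch below assumes y_true and y_pred arrive with shape (batch, 1), which is common for a single-output binary classifier but is not stated in the question; the values are made up for illustration. It only compares what tf.where and tf.gather return before and after flattening.

```python
import tensorflow as tf

# Hypothetical batch (assumption): targets and predictions of shape (batch, 1).
y_true = tf.constant([[0.0], [1.0], [1.0], [0.0]])
y_pred = tf.constant([[0.1], [0.8], [0.6], [0.3]])

# On a 2-D tensor, tf.where returns one (row, col) pair per match.
idx_2d = tf.where(tf.equal(y_true, 0.0))
# tf.gather treats every number in idx_2d as a row index along axis 0,
# so each column index 0 pulls in an extra copy of y_pred[0].
gathered_2d = tf.gather(y_pred, idx_2d)

# After flattening, each match is described by a single index.
y_true_flat = tf.reshape(y_true, [-1])
y_pred_flat = tf.reshape(y_pred, [-1])
idx_1d = tf.where(tf.equal(y_true_flat, 0.0))
gathered_1d = tf.gather(y_pred_flat, idx_1d)

print("2-D input:", idx_2d.shape, "->", gathered_2d.shape)
print("1-D input:", idx_1d.shape, "->", gathered_1d.shape)
```

With the 2-D input, the gathered tensor holds more entries than there are samples labelled 0, and the extras are copies of the first prediction, so the averaged loss would be skewed toward that sample. With the flattened input, exactly one prediction per matching sample is gathered.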