python, tensorflow, keras

Apply different loss function to part of tensor in keras


I'm trying to build a custom loss function that applies a different function to different parts of a tensor, depending on the ground truth.

For example, say the ground truth is:

[0 1 1 0]

I want to apply log(n) to indices 1 and 2 of the output tensor (the positions whose ground-truth value is 1), and log(n-1) to the rest.

How can I achieve this?
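To pin down the per-element rule described above, here is a plain-Python reference version (no TensorFlow), which can be handy for checking a tensor implementation against. The function name `reference_loss` and the example outputs are hypothetical; as in the question, `log(n-1)` is only defined where n > 1.

```python
import math


def reference_loss(ground_truth, outputs):
    """Hypothetical reference for the rule in the question:
    log(n) where the ground truth is 1, log(n - 1) everywhere else."""
    # math.log raises ValueError for non-positive arguments,
    # so outputs at ground-truth-0 positions must be > 1.
    return [
        math.log(n) if t == 1 else math.log(n - 1)
        for t, n in zip(ground_truth, outputs)
    ]


# Example outputs n, chosen > 1 so that log(n - 1) is defined.
print(reference_loss([0, 1, 1, 0], [1.5, 2.0, 3.0, 2.5]))
```

Each entry corresponds to one position: log(0.5), log(2.0), log(3.0), log(1.5).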


Solution

  • You can create two masks.

    • The first mask is 1 where the ground truth is 1 and 0 elsewhere (it masks out the zeros), so multiplying it with the first loss term keeps log(n) only at the positions whose ground truth is 1.

    • The second mask is 1 where the ground truth is 0 and 0 elsewhere (it masks out the ones), so multiplying it with the second loss term keeps log(n-1) only at the positions whose ground truth is 0.

    Something like:

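    import tensorflow as tf

    y_true = tf.constant([0, 1, 1, 0], tf.float32)          # ground truth (used only for the masks)
    y_pred = tf.constant([1.5, 2.0, 3.0, 2.5], tf.float32)  # model output n (example values)

    # mask1 is 1 where the ground truth is 1: keep log(n) only there
    mask1 = tf.cast(tf.equal(y_true, 1.0), tf.float32)
    loss1 = tf.math.log(y_pred) * mask1

    # mask2 is 1 where the ground truth is 0: keep log(n - 1) only there.
    # Both logs are evaluated everywhere, so n must be > 1 everywhere (0 * NaN is NaN).
    mask2 = tf.cast(tf.equal(y_true, 0.0), tf.float32)
    loss2 = tf.math.log(y_pred - 1) * mask2

    overall_loss = tf.add(loss1, loss2)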
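Multiplying by masks evaluates both logs at every position, so an undefined value in a masked-out position still produces NaN, because 0 * NaN (and 0 * -inf) is NaN. A minimal sketch of the same idea as a Keras-compatible loss, assuming TensorFlow 2.x, selects per element with `tf.where` instead. The name `masked_log_loss`, the placeholder value 2.0 fed to the unused branch, and the averaging over the last axis are illustrative choices, not part of the original answer. As in the question, log(n-1) is still only defined where n > 1 at ground-truth-0 positions.

```python
import tensorflow as tf


def masked_log_loss(y_true, y_pred):
    """Hypothetical Keras-style loss: log(n) where y_true == 1, log(n - 1) elsewhere."""
    y_true = tf.cast(y_true, y_pred.dtype)
    is_one = tf.equal(y_true, 1.0)

    # Feed each branch a harmless value (2.0) at the positions it does not own,
    # so neither log sees a bad argument there. This keeps both the forward value
    # and the gradients finite, since tf.where backpropagates through both branches.
    placeholder = tf.ones_like(y_pred) * 2.0
    arg_log_n = tf.where(is_one, y_pred, placeholder)
    arg_log_n_minus_1 = tf.where(is_one, placeholder, y_pred)

    per_element = tf.where(
        is_one,
        tf.math.log(arg_log_n),               # used where ground truth is 1
        tf.math.log(arg_log_n_minus_1 - 1.0), # used where ground truth is 0
    )
    # Keras expects one loss value per sample: average over the last axis.
    return tf.reduce_mean(per_element, axis=-1)


y_true = tf.constant([[0.0, 1.0, 1.0, 0.0]])
y_pred = tf.constant([[1.5, 2.0, 3.0, 2.5]])
print(masked_log_loss(y_true, y_pred))
```

It can then be passed to Keras like any built-in loss, e.g. `model.compile(optimizer="adam", loss=masked_log_loss)`. The double `tf.where` is the design choice that matters: selecting with a single `tf.where` fixes the forward value, but a NaN or infinite gradient from the unused branch can still leak into training.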