
TensorFlow: How to ignore parts of an array when calling tf.reduce_sum


I want to change the typical MSE loss function. Right now I have the following code:

squared_difference = tf.reduce_sum(tf.square(target - output), [1])
mse_loss = tf.reduce_mean(squared_difference)

The shape of both tensors is [batch_size, 10], and an example target is [0,1,2,3,0.5,0.5,0.5,7,8,9]. The 0.5s are always at indices 4, 5 and 6.

What I want to do now is ignore these indices completely, so that the loss does not increase when the network's output differs from 0.5 at these indices.

So if the output is [0,1,2,3,20,10,14,7,8,9], the loss should be 0.
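For reference, the requested behaviour can be sketched directly on top of the tf.reduce_sum code above. This is a minimal sketch, assuming TensorFlow 2.x with eager execution; the names IGNORED_INDICES, NUM_OUTPUTS, mask and masked_mse_loss are illustrative and not from the original post. It multiplies the squared differences by a 0/1 mask before tf.reduce_sum, so columns 4, 5 and 6 contribute nothing, and then evaluates the example above.

```python
import tensorflow as tf

# Hypothetical constants: positions to ignore (always 0.5 in the targets).
IGNORED_INDICES = [4, 5, 6]
NUM_OUTPUTS = 10

# 1.0 for positions that count toward the loss, 0.0 for ignored ones; shape [10].
mask = tf.constant(
    [0.0 if i in IGNORED_INDICES else 1.0 for i in range(NUM_OUTPUTS)],
    dtype=tf.float32,
)

def masked_mse_loss(target, output):
    """Per-example sum of squared differences over kept columns, averaged over the batch."""
    squared = tf.square(target - output)              # shape [batch_size, 10]
    squared = squared * mask                          # broadcasting zeroes columns 4-6
    squared_difference = tf.reduce_sum(squared, [1])  # shape [batch_size]
    return tf.reduce_mean(squared_difference)

target = tf.constant([[0, 1, 2, 3, 0.5, 0.5, 0.5, 7, 8, 9]], dtype=tf.float32)
output = tf.constant([[0, 1, 2, 3, 20, 10, 14, 7, 8, 9]], dtype=tf.float32)
print(masked_mse_loss(target, output).numpy())  # 0.0
```

Multiplying by a mask keeps the tensor shape unchanged, so the rest of the original loss code stays the same; selecting only the kept columns with tf.gather would be an equivalent alternative.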

What is the best possible way to achieve this?


Solution

  • There are many ways to handle this. One straightforward way is to use the weights parameter of tf.losses.mean_squared_error: pass a [batch_size, num_labels] tensor that acts as a mask, with 1s for the values you want to include in the loss and 0s for the values to ignore. Most loss functions accept a weights parameter.
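A sketch of the weights approach, assuming TensorFlow 2.x. In 2.x, tf.losses refers to the Keras losses, whose mean_squared_error takes no weights argument, so the sketch calls the TF1-style function through tf.compat.v1.losses.mean_squared_error. The helper name masked_mse_with_weights and the constants are hypothetical.

```python
import tensorflow as tf

# Hypothetical constants: positions to ignore (always 0.5 in the targets).
IGNORED_INDICES = [4, 5, 6]
NUM_OUTPUTS = 10

def masked_mse_with_weights(target, output):
    # Row of 1s (keep) and 0s (ignore), shape [10].
    keep = [0.0 if i in IGNORED_INDICES else 1.0 for i in range(NUM_OUTPUTS)]
    # Tile the row to [batch_size, 10] so the mask matches the labels' shape.
    weights = tf.broadcast_to(tf.constant(keep, dtype=target.dtype), tf.shape(target))
    # Default reduction is SUM_BY_NONZERO_WEIGHTS: the weighted sum of squared
    # errors is divided by the number of non-zero weights.
    return tf.compat.v1.losses.mean_squared_error(
        labels=target, predictions=output, weights=weights
    )

target = tf.constant([[0, 1, 2, 3, 0.5, 0.5, 0.5, 7, 8, 9]], dtype=tf.float32)
output = tf.constant([[0, 1, 2, 3, 20, 10, 14, 7, 8, 9]], dtype=tf.float32)
print(masked_mse_with_weights(target, output).numpy())  # 0.0
```

Because of that default reduction, the result is the mean over the 7 kept positions rather than the question's per-example sum averaged over the batch; here the two differ by a constant factor of 7, which does not change where the minimum lies.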