Tags: python, tensorflow, keras, loss-function

Keras custom loss with range selection - treat values independently


I have a custom loss defined as follows:

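import tensorflow as tf

def custom_mse():
    def mse(y_true, y_pred):
        # Mask of positions where the true label is above 0.5
        great = tf.greater(y_true, 0.5)
        # Squared error restricted to the masked positions (zero elsewhere)
        loss = tf.square(tf.where(great, y_true, tf.zeros_like(y_true)) - tf.where(great, y_pred, tf.zeros_like(y_pred)))
        # Squared error over all positions
        loss_low = tf.square(y_true - y_pred)
        return loss + loss_low
    # Return the inner function so Keras can call it as the loss
    return mse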

This loss sums two terms: the squared error between the untouched predictions and the true labels, and the squared error restricted to positions where the true label is above 0.5.

This works perfectly, but now I would like to do something different.

I want to create a loss that treats predicted values in the range 0.01-0.2 as ones, so the model gets feedback that those are incorrect (lowering these values matters more). From this thread I found a way to select the range:

lower_tensor = K.tf.greater(y_pred, 0.01)
upper_tensor = K.tf.less(y_pred, 0.2)
in_range = K.tf.logical_and(lower_tensor, upper_tensor)

However, I can't find a way to apply this mask to the original tensor so that the selected values are handled as ones in the calculation. I want to do something like this (just an example, this won't work):

tf.where(in_range, y_pred) = 1
loss = K.square(y_true-y_pred)

Is there a way to achieve this? Or do I need to split the losses again and count them together?


Solution

  • You can use tf.where with tf.ones_like and your original tensor.

    An example (with y_pred set to [0.0, 0.1, 0.8, 0.15]):

    import tensorflow as tf
    
    y_pred = tf.constant([0.0, 0.1, 0.8, 0.15])
    lower_tensor = tf.greater(y_pred, 0.01)
    upper_tensor = tf.less(y_pred, 0.2)
    in_range = tf.logical_and(lower_tensor, upper_tensor)
    # tf.where is (cond, tensor if cond is true, tensor if cond is false)
    y_pred_w_ones = tf.where(in_range, tf.ones_like(y_pred), y_pred)
    

    And we get

    >>> y_pred_w_ones
     <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0. , 1. , 0.8, 1. ], dtype=float32)>
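
The replacement above can be wrapped into a loss function. The following is a minimal sketch, not a definitive implementation: the name range_penalty_mse, the low/high parameters, and the mean over the last axis are assumptions made for illustration and do not come from the question. It also uses tf.GradientTape to inspect which positions receive a gradient.

```python
import tensorflow as tf

def range_penalty_mse(low=0.01, high=0.2):
    """Hypothetical helper: build a loss that treats predictions in (low, high) as 1.0."""
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Boolean mask of predictions that fall inside the open interval (low, high)
        in_range = tf.logical_and(tf.greater(y_pred, low), tf.less(y_pred, high))
        # Use 1.0 where the mask is true, the original prediction elsewhere
        y_pred_w_ones = tf.where(in_range, tf.ones_like(y_pred), y_pred)
        # Mean squared error over the last axis (an assumption; the question keeps per-element values)
        return tf.reduce_mean(tf.square(y_true - y_pred_w_ones), axis=-1)
    return loss

y_true = tf.constant([[0.0, 0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.0, 0.1, 0.8, 0.15]])

loss_fn = range_penalty_mse()
with tf.GradientTape() as tape:
    tape.watch(y_pred)  # y_pred is a constant here, so it must be watched explicitly
    loss_value = loss_fn(y_true, y_pred)
grads = tape.gradient(loss_value, y_pred)

# Plain MSE on the same data, for comparison
plain_mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

print(loss_value.numpy(), plain_mse.numpy())
print(grads.numpy())
```

With model.compile(loss=range_penalty_mse(), ...) the loss value rises sharply when predictions land in the range. In this sketch, however, the gradient at the replaced positions is zero, because the constant 1.0 does not depend on y_pred. If the goal is for the model to be pushed away from that range, splitting the loss into two terms that both depend on y_pred, as in the original custom_mse, may be the better fit.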