python tensorflow deep-learning neural-network loss-function

Act on condition in custom loss function


I'm defining a model in TensorFlow with my own custom loss function. It adds a penalty term to the usual MSE:

return K.mean(K.square(y_true - y_pred)) + penalty

where the penalty is computed from products and sums of some input elements and some targets (I've omitted that code because it is unrelated to the question).

Everything works, but I now need to add a new kind of example to the dataset, in which some features and the output are zero.

So a training batch can contain both types of examples, with and without zero values, and I need some kind of conditional applied across the entire batch so that the penalty is computed in one of two ways depending on how the condition evaluates.

Can it be done?


Solution

  • You can use tf.where:

    For example:

    • If an element is zero, return it unchanged
    • If an element is not zero, return the element + 1
    >>> batch_1 = tf.constant([0.,1.,0.,0.5,3.,-2.,0.])
    >>> tf.where(batch_1 == 0, batch_1, batch_1+1)
    <tf.Tensor: shape=(7,), dtype=float32, numpy=array([0. , 2. , 0. , 1.5, 4. , -1. , 0. ], dtype=float32)>
    

    It also works with multi-dimensional tensors, for example:

    • If the second element of a row is zero, return the row's first element
    • Otherwise, return the sum of the whole row
    >>> a = tf.constant([[1,0,3],[3,4,9],[8,0,1]],tf.float32)
    >>> tf.where(a[:,1]==0,tf.reduce_sum(a[:,:1],axis=1), tf.reduce_sum(a,axis=1))
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1., 16.,  8.], dtype=float32)>
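
Applied to the loss from the question, the same idea might look like the sketch below. The real penalty code was not posted, so `penalty_a` and `penalty_b` are hypothetical stand-ins, and the condition (all targets of an example are zero) is an assumption about how the new examples can be recognised; adapt both to your data.

```python
import tensorflow as tf


def penalty_a(y_true, y_pred):
    # Hypothetical penalty for regular examples (placeholder, not the asker's code).
    return tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)


def penalty_b(y_true, y_pred):
    # Hypothetical penalty for zero-valued examples (placeholder, not the asker's code).
    return tf.reduce_mean(tf.square(y_pred), axis=-1)


def custom_loss(y_true, y_pred):
    y_true = tf.cast(y_true, y_pred.dtype)
    # Per-example MSE, shape (batch,).
    mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
    # Assumed condition: True for examples whose targets are all zero.
    is_zero_example = tf.reduce_all(tf.equal(y_true, 0.0), axis=-1)
    # Pick one penalty per example; both are computed for the whole batch.
    penalty = tf.where(
        is_zero_example,
        penalty_b(y_true, y_pred),
        penalty_a(y_true, y_pred),
    )
    return tf.reduce_mean(mse + penalty)


y_true = tf.constant([[0.0, 0.0], [1.0, 2.0]])
y_pred = tf.constant([[0.5, 0.5], [1.0, 1.0]])
print(float(custom_loss(y_true, y_pred)))  # 0.75
```

Note that tf.where only selects between two tensors that have already been computed, so both penalties are evaluated for every example in the batch; each one must therefore be well defined on examples of the other kind.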