Custom loss/objective function with additional variable input in Keras

I am trying to create a custom objective function in Keras (TensorFlow backend) that takes an additional parameter whose value depends on the batch being trained.


def myLoss(self, stateValues):
    def sparse_loss(y_true, y_pred):
        foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        return tf.reduce_mean(foo * stateValues)
    return sparse_loss

self.model.compile(loss=self.myLoss(stateValues=self.stateValue),

My training loop is as follows:

for batch in batches:
    self.stateValue = computeStateValueVectorForCurrentBatch(batch)
    self.model.fit(..., yVals, batch_size=<num>)

However, the stateValue used inside the loss function is never updated; it keeps the value it had when model.compile was called.

I guess this could be solved by using a placeholder for stateValue, but I cannot figure out how to do it. Can someone please help?
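The behaviour described in the question is consistent with how Python binds function arguments: the loss factory receives the value once, and rebinding the attribute afterwards does not reach the closure. A minimal Keras-free sketch of that effect, where `Trainer` and `make_loss` are hypothetical names used only for illustration:

```python
class Trainer:
    def __init__(self):
        self.state_value = 1.0

    def make_loss(self, state_value):
        # state_value is bound once, at the moment make_loss is called.
        def loss(error):
            return error * state_value
        return loss


trainer = Trainer()
loss_fn = trainer.make_loss(trainer.state_value)  # plays the role of compile()

trainer.state_value = 5.0      # rebinding the attribute later ...
stale_result = loss_fn(2.0)    # ... does not change what the closure captured
print(stale_result)
```

In Keras there is a further layer on top of this: the loss is built into the training function, so the captured value is also fixed there unless it lives in something the backend can update, such as a variable.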


  • Your loss function is not being updated because Keras compiles the model only once, not after each batch, so training keeps using the loss that was built at compile time.

    You can define a custom callback that updates the loss after each batch. Something like this:

    from keras.callbacks import Callback

    class UpdateLoss(Callback):
        def on_batch_end(self, batch, logs=None):
            # I am not sure what type of argument you pass to compute stateValue.
            # Note that `batch` here is the batch index, not the batch data.
            stateValue = computeStateValueVectorForCurrentBatch(batch)
            # Rebuild the loss with the new value (this may only take effect
            # once the model is recompiled).
            self.model.loss = myLoss(stateValue)
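Another way to get the placeholder-like behaviour the question asks about is to keep the state values in a non-trainable variable that the loss reads on every call, and to assign a new value before each training step. What follows is a minimal sketch, not a definitive implementation. It assumes TensorFlow 2.x with `tf.keras`, a fixed batch size, one-hot labels, and a model that outputs logits. The names `make_weighted_loss` and `compute_state_values`, the toy model, and the random data are hypothetical stand-ins for the asker's own code.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 4    # assumed fixed batch size for this sketch
NUM_CLASSES = 3   # assumed number of classes

# Non-trainable variable holding the per-sample weights. The loss reads it
# each time it runs, so assigning a new value affects the next batch.
state_value = tf.Variable(tf.ones([BATCH_SIZE]), trainable=False)


def make_weighted_loss(weights):
    def weighted_loss(y_true, y_pred):
        per_sample = tf.nn.softmax_cross_entropy_with_logits(
            labels=y_true, logits=y_pred)
        return tf.reduce_mean(per_sample * weights)
    return weighted_loss


def compute_state_values(batch_x):
    # Hypothetical stand-in for computeStateValueVectorForCurrentBatch.
    return np.abs(batch_x).mean(axis=1).astype("float32")


loss_fn = make_weighted_loss(state_value)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(NUM_CLASSES),  # outputs logits, no softmax
])
model.compile(optimizer="sgd", loss=loss_fn)

rng = np.random.default_rng(0)
for _ in range(3):
    batch_x = rng.normal(size=(BATCH_SIZE, 5)).astype("float32")
    batch_y = np.eye(NUM_CLASSES, dtype="float32")[
        rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)]
    # Update the variable in place before the training step; no recompile.
    state_value.assign(compute_state_values(batch_x))
    model.train_on_batch(batch_x, batch_y)
```

The model is compiled once; only the contents of `state_value` change between batches. With the older standalone Keras backend API, `K.variable` and `K.set_value` play a similar role.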