Tags: python, tensorflow, keras, tensor, loss-function

Is there a way to monitor the different subtensors of a custom Keras backend loss function?


I'm currently implementing a custom loss function by modelling it as a Keras backend tensor. The loss function has several parts (such as a classification loss, a quantization loss, and a pairwise distance loss).

The code looks something like this:

...
# Margin-based penalty for pairs from different classes
different_class_loss = K.log(1 + K.exp(-1 * dissimilarity + margin))
# Combine the per-pair terms depending on whether the pair shares a class
pair_loss = same_class * same_class_loss + (1 - same_class) * different_class_loss

# Weighted sum of the individual loss components
loss_value = lambda_pair * pair_loss + lambda_classif * classif_loss + lambda_quant_binary * quantization_loss

# Add loss to model
pairwise_model.add_loss(loss_value)

# Compile without specifying a loss
pairwise_model.compile(optimizer=optimizer_adam)

When I train the model using a batch_generator and pairwise_model.fit(), the history contains exactly one loss entry for the combined loss_value. For debugging purposes I'd like to monitor every part of that loss function individually (i.e. the quantization, classification, and pairwise distance losses), but I can't figure out how.

I tried implementing a callback using K.eval() or K.print_tensor() to retrieve the values during training, but that didn't work. I also wasn't able to register the individual losses as metrics via the add_loss function.

Is there a way to do this without writing a custom training loop? It feels like there should be. Any help is greatly appreciated.

__________________________________________________

EDIT:

Following the idea from Dr. Snoopy, here is the code that ended up working for me:

...
different_class_loss = K.log(1 + K.exp(-1 * dissimilarity + margin))
pair_loss = same_class * same_class_loss + (1 - same_class) * different_class_loss

loss_value = lambda_pair * pair_loss + lambda_classif * classif_loss + lambda_quant_binary * quantization_loss

# Add loss to model
pairwise_model.add_loss(loss_value)

# Add additional losses as metrics
pairwise_model.add_metric(pair_loss, name="pairwise loss")
pairwise_model.add_metric(quantization_loss, name="quantization loss")

# Compile without specifying a loss or metrics
pairwise_model.compile(optimizer=optimizer_adam)
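
With the components registered through add_metric, each one now appears in the training history alongside the combined loss, under the name it was given. A minimal sketch of what the monitoring looks like (the epoch count here is just a placeholder, not part of the original setup):

# Train as before; Keras reports every registered metric per epoch
history = pairwise_model.fit(batch_generator, epochs=10)

# The combined loss and the individual components are available by name
print(history.history.keys())
# e.g. dict_keys(['loss', 'pairwise loss', 'quantization loss'])
print(history.history["pairwise loss"])  # per-epoch values of the pairwise term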

Solution

  • You can pass them as metrics, like this:

    def pl(y_true, y_pred):
        # Keras calls metric functions with (y_true, y_pred); both arguments
        # are ignored here and the symbolic pair_loss tensor is returned
        return pair_loss

    pairwise_model.compile(optimizer=optimizer_adam, metrics=[pl])
    

    You can do the same for your other loss components. The wrapper function might not be needed; you could also try passing pair_loss directly as a metric.