
How does Keras pass y_pred to a loss object/function via model.compile?


If I have a function that defines the triplet loss (which expects y_true and y_pred as input parameters), and I reference it in the following call:

model.compile(optimizer="rmsprop", loss=triplet_loss, metrics=["accuracy"])

How does y_pred get passed to the triplet_loss function?

For example the triplet_loss function may be:

import tensorflow as tf

def triplet_loss(y_true, y_pred, alpha=0.2):
    """
    Implementation of the triplet loss function
    Arguments:
    y_true -- true labels, required when you define a loss in Keras, but not used here
    y_pred -- python list containing three objects: the anchor, positive and negative encodings
    alpha -- margin enforced between the positive and negative distances
    Returns:
    loss -- the value of the loss
    """
    anchor, positive, negative = y_pred[0], y_pred[1], y_pred[2]
    # distance between the anchor and the positive
    pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)))
    # distance between the anchor and the negative
    neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)))
    # compute loss: penalize when the negative is not at least alpha farther away
    basic_loss = pos_dist - neg_dist + alpha
    loss = tf.maximum(basic_loss, 0.0)
    return loss
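To make the arithmetic of the function above concrete, here is a minimal plain-Python sketch of the same formula for a single anchor/positive/negative triple. It assumes toy 2-D encodings and does not use TensorFlow; `squared_distance` and `triplet_loss_py` are hypothetical names introduced only for illustration.

```python
# Plain-Python restatement of the triplet loss above (no TensorFlow),
# applied to one anchor/positive/negative triple of toy 2-D encodings.
# squared_distance and triplet_loss_py are hypothetical illustration helpers.

def squared_distance(a, b):
    """Sum of squared element-wise differences, like reduce_sum(square(a - b))."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def triplet_loss_py(anchor, positive, negative, alpha=0.2):
    pos_dist = squared_distance(anchor, positive)
    neg_dist = squared_distance(anchor, negative)
    # Same hinge as tf.maximum(basic_loss, 0.0)
    return max(pos_dist - neg_dist + alpha, 0.0)

# Negative is already far enough away: 1 - 4 + 0.2 < 0, so the loss is clipped.
print(triplet_loss_py([0.0, 0.0], [1.0, 0.0], [2.0, 0.0]))  # 0.0
# Negative is closer than the positive: 4 - 1 + 0.2, a positive loss.
print(triplet_loss_py([0.0, 0.0], [2.0, 0.0], [1.0, 0.0]))
```

The loss is zero once the negative is at least `alpha` farther (in squared distance) from the anchor than the positive is.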

Thanks Jon


Solution

  • I did some poking through the Keras source code. In the `Model` class's `compile()` method (standalone Keras 2.x):

    First, each loss function is wrapped so that it also takes sample weights and masks into account:

    self.loss_functions = loss_functions
    weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions]
    

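    Later in `compile()`, each model output (prediction) is paired with its target (label), and the wrapped loss is called to get `output_loss`. This is where `y_true` and `y_pred` are passed into your function. Both are symbolic tensors (`self.targets[i]` is a placeholder for the labels, `self.outputs[i]` is the model's output tensor), so your loss function is called once to build the graph rather than once per batch.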

    y_true = self.targets[i]
    y_pred = self.outputs[i]
    weighted_loss = weighted_losses[i]
    sample_weight = sample_weights[i]
    mask = masks[i]
    loss_weight = loss_weights_list[i]
    with K.name_scope(self.output_names[i] + '_loss'):
        output_loss = weighted_loss(y_true, y_pred,
                                    sample_weight, mask)
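The mechanism above can be sketched with a toy, eagerly evaluated model. This is a simplified illustration under stated assumptions, not the real Keras code: `ToyModel`, `weighted_objective` and `squared_error` are hypothetical names, masks are omitted, and plain lists stand in for tensors. It shows the two steps from the answer: `compile()` only stores the function object (passed without parentheses) and wraps it for sample weights, and the wrapper later calls it as `fn(y_true, y_pred)` with the targets and the model's outputs.

```python
# Toy model of the Keras 2.x flow described above (hypothetical, simplified):
# compile() stores and wraps the loss; the wrapper later calls fn(y_true, y_pred).

def weighted_objective(fn):
    """Wrap a batch loss so it also applies per-sample weights (masks omitted)."""
    def weighted(y_true, y_pred, sample_weight=None):
        score_array = fn(y_true, y_pred)  # one loss value per sample
        if sample_weight is not None:
            score_array = [s * w for s, w in zip(score_array, sample_weight)]
        return sum(score_array) / len(score_array)
    return weighted

class ToyModel:
    def __init__(self, forward):
        self.forward = forward  # maps one input to one prediction

    def compile(self, loss):
        # The loss is passed by reference, so nothing is called yet.
        self.loss_function = loss
        self.weighted_loss = weighted_objective(loss)

    def evaluate_loss(self, x, y, sample_weight=None):
        y_pred = [self.forward(xi) for xi in x]  # the model's outputs
        y_true = y                                # the targets supplied by the caller
        return self.weighted_loss(y_true, y_pred, sample_weight)

def squared_error(y_true, y_pred):
    return [(t - p) ** 2 for t, p in zip(y_true, y_pred)]

model = ToyModel(forward=lambda x: 2 * x)
model.compile(loss=squared_error)
print(model.evaluate_loss([1.0, 2.0], [2.0, 3.0]))  # 0.5
print(model.evaluate_loss([1.0, 2.0], [2.0, 3.0], sample_weight=[1.0, 3.0]))  # 1.5
```

Because the loss is called with exactly two arguments, any extra parameter with a default, such as `alpha` in `triplet_loss`, keeps its default value.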