Tags: python, tensorflow, keras, loss-function

Keras/Tensorflow: Combined Loss function for single output


I have only one output for my model, but I would like to combine two different loss functions:

def get_model():
    # create the model here
    model = Model(inputs=image, outputs=output)

    alpha = 0.2
    model.compile(loss=[mse, gse],
                  loss_weights=[1 - alpha, alpha],
                  ...)

But Keras complains that I need two outputs because I defined two losses:

ValueError: When passing a list as loss, it should have one entry per model outputs. 
The model has 1 outputs, but you passed loss=[<function mse at 0x0000024D7E1FB378>, <function gse at 0x0000024D7E1FB510>]

Can I write my final loss without creating another loss function? (A separate function would restrict me from changing alpha outside of it.)

How do I do something like (1-alpha)*mse + alpha*gse?


Update:

Both of my loss functions have the same signature as any built-in Keras loss function: they take y_true and y_pred and return a loss tensor (which can be reduced to a scalar with K.mean()). I believe how these loss functions are defined shouldn't affect the answer, as long as they return valid losses.

from keras import backend as K

def gse(y_true, y_pred):
    # some tensor operation on y_pred and y_true
    return K.mean(K.square(y_pred - y_true), axis=-1)
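Why the exact definitions shouldn't matter can be checked with a small sketch. Here NumPy arrays stand in for Keras tensors (an assumption for illustration; the sample values are made up): each loss returns one value per sample, K.mean() would reduce that vector to a scalar, and a weighted sum of two such vectors keeps the same per-sample shape, so Keras can treat it like any single loss. Because the placeholder body of gse above is itself a mean squared error, the blend here equals plain MSE for every alpha.

```python
import numpy as np

# NumPy mirrors of the Keras losses: each takes (y_true, y_pred) and
# returns one loss value per sample (mean over the last axis)
def mse(y_true, y_pred):
    return np.mean(np.square(y_pred - y_true), axis=-1)

def gse(y_true, y_pred):
    # same placeholder body as the gse in the question
    return np.mean(np.square(y_pred - y_true), axis=-1)

y_true = np.array([[0.0, 0.0], [1.0, 1.0]])
y_pred = np.array([[1.0, 1.0], [1.0, 3.0]])

per_sample = gse(y_true, y_pred)   # shape (2,): [1.0, 2.0]
scalar = per_sample.mean()         # what K.mean() would reduce it to: 1.5

alpha = 0.2
combined = (1 - alpha) * mse(y_true, y_pred) + alpha * gse(y_true, y_pred)
print(combined.shape)  # (2,), the same shape as each individual loss
```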

Solution

  • Specify a custom function for the loss:

    model = Model(inputs=image, outputs=output)
    
    alpha = 0.2
    model.compile(
        loss=lambda y_true, y_pred: (1 - alpha) * mse(y_true, y_pred) + alpha * gse(y_true, y_pred),
        ...)
    

    Or, if you don't want an ugly lambda, make it a regular function:

    def my_loss(y_true, y_pred):
        return (1 - alpha) * mse(y_true, y_pred) + alpha * gse(y_true, y_pred)
    
    model = Model(inputs=image, outputs=output)
    
    alpha = 0.2
    model.compile(loss=my_loss, ...)
    

    EDIT:

    If alpha is not a global constant, you can use a "loss function factory":

    def make_my_loss(alpha):
        def my_loss(y_true, y_pred):
            return (1 - alpha) * mse(y_true, y_pred) + alpha * gse(y_true, y_pred)
        return my_loss
    
    model = Model(inputs=image, outputs=output)
    
    alpha = 0.2
    my_loss = make_my_loss(alpha)
    model.compile(loss=my_loss, ...)
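The factory fixes alpha when the loss is created. If the goal is to change alpha from outside the loss while training runs (the concern raised in the question), one common option is to make alpha a non-trainable variable and update it from a callback. This is a sketch under assumptions, for TensorFlow 2.x with tf.keras: the toy Dense model, the random data, the AlphaSchedule callback, and the mean-absolute-error body of gse are all hypothetical placeholders.

```python
import numpy as np
import tensorflow as tf

def mse(y_true, y_pred):
    # per-sample mean squared error over the last axis
    return tf.reduce_mean(tf.square(y_pred - y_true), axis=-1)

def gse(y_true, y_pred):
    # hypothetical stand-in for the question's gse: per-sample mean absolute error
    return tf.reduce_mean(tf.abs(y_pred - y_true), axis=-1)

def make_my_loss(alpha):
    # alpha is a non-trainable tf.Variable, so the traced training step
    # reads its current value instead of a frozen Python float
    def my_loss(y_true, y_pred):
        return (1.0 - alpha) * mse(y_true, y_pred) + alpha * gse(y_true, y_pred)
    return my_loss

class AlphaSchedule(tf.keras.callbacks.Callback):
    # hypothetical schedule: raise alpha by `step` after each epoch, capped at 1.0
    def __init__(self, alpha, step=0.1):
        super().__init__()
        self.alpha = alpha
        self.step = step

    def on_epoch_end(self, epoch, logs=None):
        self.alpha.assign(min(1.0, float(self.alpha.numpy()) + self.step))

alpha = tf.Variable(0.2, trainable=False, dtype=tf.float32)

# hypothetical toy model in place of Model(inputs=image, outputs=output)
inputs = tf.keras.Input(shape=(4,))
output = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=output)
model.compile(optimizer="adam", loss=make_my_loss(alpha))

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 2).astype("float32")
history = model.fit(x, y, epochs=3, verbose=0, callbacks=[AlphaSchedule(alpha)])
```

After fit, alpha has stepped from 0.2 to 0.5 over the three epochs, and each epoch's training steps used the value current at that time. If alpha only needs to be set once before training, the plain-float factory is simpler.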