Tags: python, tensorflow, keras, deep-learning, loss-function

How can I train a Keras model with one output and multiple y_true?


I would like to train a Keras model that has only one output but multiple y_true targets, like this:

def CustomLossFunc(y_trues, y_pred):
    y_true1, y_true2 = y_trues  # two targets for the same prediction
    Score1 = func1(y_true1, y_pred)
    Score2 = func2(y_true2, y_pred)
    return Score1 + Score2

Is it possible in Keras?

My idea: maybe I could split the single y_pred into two identical copies, y_pred1 and y_pred2, train it like a model with two outputs, and assign one y_true to each y_pred. But that seems a little messy; maybe there is a better way.
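For comparison, the two-output idea could look roughly like the minimal sketch below. It assumes TensorFlow 2.x with its bundled Keras; the layer names `out1`/`out2` and the random data are hypothetical. A shared Dense layer feeds two linear (identity) activation "heads", so both outputs carry the same prediction, and Keras sums the per-output losses.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

N = 5
X = np.random.rand(8, N).astype("float32")
Y1 = np.random.rand(8, N).astype("float32")  # first target (hypothetical data)
Y2 = np.random.rand(8, N).astype("float32")  # second target (hypothetical data)

inputs = keras.Input(shape=(N,))
shared = layers.Dense(N, activation="relu")(inputs)
# Two identity "heads" expose the same prediction under two different names
out1 = layers.Activation("linear", name="out1")(shared)
out2 = layers.Activation("linear", name="out2")(shared)

model = keras.Model(inputs=inputs, outputs=[out1, out2])
# One loss per output; Keras adds them (loss_weights default to 1.0)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              loss=["mae", "mse"])

history = model.fit(X, [Y1, Y2], batch_size=4, epochs=1, verbose=0)
```

This works, but it duplicates the output only to satisfy the one-target-per-output convention, which is the messiness the question wants to avoid.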


Solution

  • When using model.fit, you can pack both targets into a single Y (concatenated along the last axis) and slice y_true apart inside the custom loss function:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import numpy as np
    
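    BS = 3
    N = 5
    X = np.random.rand(BS, N)
    # Y holds both targets side by side:
    # columns [:N] are y_true1 and columns [N:] are y_true2
    Y = np.random.rand(BS, N * 2)
    
    def CustomLossFunc(y_true, y_pred):
        # Slice the packed y_true back into the two targets
        y_true1 = y_true[:, :N]
        y_true2 = y_true[:, N:]
        Score1 = MAE(y_true1, y_pred)
        Score2 = MSE(y_true2, y_pred)
        return Score1 + Score2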
    
    def MAE(y_true, y_pred):
        return tf.reduce_mean(tf.abs(y_true - y_pred))
    
    def MSE(y_true, y_pred):
        return tf.reduce_mean((y_true - y_pred)**2.)
    
    input_shape = (N,)
    input_layer = keras.Input(shape=input_shape)
    output_layer = layers.Dense(N, 'relu')(input_layer)
    
    model = keras.Model(inputs=input_layer, outputs=output_layer)
    
    optimizer = keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(optimizer=optimizer, loss=CustomLossFunc)
    
    model.fit(X, Y, batch_size=BS, epochs=1)
    

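To see that slicing really recovers both targets, here is a quick check (a sketch assuming TensorFlow 2.x eager execution; the arrays are random, and the loss functions from the answer are redefined so the block runs on its own). The loss computed on the packed y_true should match computing MAE on the first target and MSE on the second one separately.

```python
import numpy as np
import tensorflow as tf

N = 5
rng = np.random.default_rng(0)
y_true1 = tf.constant(rng.random((3, N)).astype("float32"))
y_true2 = tf.constant(rng.random((3, N)).astype("float32"))
y_pred = tf.constant(rng.random((3, N)).astype("float32"))

# Pack both targets side by side, the same layout model.fit receives as Y
y_true = tf.concat([y_true1, y_true2], axis=1)

def MAE(y_true, y_pred):
    return tf.reduce_mean(tf.abs(y_true - y_pred))

def MSE(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred) ** 2.0)

def CustomLossFunc(y_true, y_pred):
    # Column slices recover the two original targets
    return MAE(y_true[:, :N], y_pred) + MSE(y_true[:, N:], y_pred)

packed = float(CustomLossFunc(y_true, y_pred))
separate = float(MAE(y_true1, y_pred) + MSE(y_true2, y_pred))
print(packed, separate)
```

The only requirement is that the model output has the same width as each slice (N here), so that y_pred can be compared with both halves.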
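    If you write your own training loop with tf.GradientTape instead, just add the losses together:

    # The two targets (cast to float32 to match the model output)
    y_true1 = Y[:, :N].astype("float32")
    y_true2 = Y[:, N:].astype("float32")

    with tf.GradientTape() as tape:
        y_pred = model(X, training=True)
        loss = MAE(y_true1, y_pred) + MSE(y_true2, y_pred)
    
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))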
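A complete, self-contained version of that custom loop might look like the sketch below. It assumes TensorFlow 2.x; `loss_1`/`loss_2` stand in for the MAE and MSE used above, the data is random, and the loop runs eagerly (wrapping `train_step` in `tf.function` is a common speed-up but is left out here).

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

N = 5
X = np.random.rand(3, N).astype("float32")
y_true1 = np.random.rand(3, N).astype("float32")  # first target (hypothetical data)
y_true2 = np.random.rand(3, N).astype("float32")  # second target (hypothetical data)

inputs = keras.Input(shape=(N,))
model = keras.Model(inputs, layers.Dense(N, activation="relu")(inputs))
optimizer = keras.optimizers.Adam(learning_rate=1e-4)

def loss_1(y_true, y_pred):  # mean absolute error
    return tf.reduce_mean(tf.abs(y_true - y_pred))

def loss_2(y_true, y_pred):  # mean squared error
    return tf.reduce_mean(tf.square(y_true - y_pred))

def train_step(x, t1, t2):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        # One prediction, two targets: the total loss is simply the sum
        loss = loss_1(t1, y_pred) + loss_2(t2, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Run a few update steps and record the scalar loss of each one
losses = [float(train_step(X, y_true1, y_true2)) for _ in range(3)]
print(losses)
```

Because the targets are passed as separate arguments here, no packing or slicing is needed in this style of training.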