
Use additional *trainable* variables in Keras/Tensorflow custom loss function


I know how to write a custom loss function in Keras that takes additional inputs beyond the standard y_true, y_pred pair (see below). My problem is passing the loss function one or more trainable variables that are part of the loss gradient and should therefore be updated during training.

My workaround is:

  • Feed the network a dummy input of size N x V, where N is the number of observations and V is the number of additional variables
  • Add a Dense() layer, dummy_output, so that Keras tracks my V "weights"
  • Use this layer's V weights in the custom loss function of my real output layer
  • Use a dummy loss function (which simply returns 0.0 and/or has weight 0.0) for the dummy_output layer, so that my V "weights" are updated only through my custom loss function

My question is: is there a more natural Keras/TF way of doing this? It feels very contrived, not to mention bug-prone.

Example of my workaround:

(Yes, I know this is a very silly custom loss function; in reality things are much more complex.)

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras import Model

n_col = 10
n_row = 1000
X = np.random.normal(size=(n_row, n_col))
beta = np.arange(10)
y = X @ beta

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# my custom loss function accepting my dummy layer with 2 variables
def custom_loss_builder(dummy_layer):
    def custom_loss(y_true, y_pred):
        var1 = dummy_layer.trainable_weights[0][0]
        var2 = dummy_layer.trainable_weights[0][1]
        return var1 * K.mean(K.square(y_true-y_pred)) + var2 ** 2 # so var2 should get to zero, var1 should get to minus infinity?
    return custom_loss

# my dummy loss function
def dummy_loss(y_true, y_pred):
    return 0.0

# my dummy input, N X V, where V is 2 for 2 vars
dummy_x_train = np.random.normal(size=(X_train.shape[0], 2)) 

# model
inputs = Input(shape=(X_train.shape[1],))
dummy_input = Input(shape=(dummy_x_train.shape[1],))
hidden1 = Dense(10)(inputs) # here only 1 hidden layer in the "real" network, assume whatever network is built here
output = Dense(1)(hidden1)
dummy_output = Dense(1, use_bias=False)(dummy_input)
model = Model(inputs=[inputs, dummy_input], outputs=[output, dummy_output])

# compilation, notice zero loss for the dummy_output layer
model.compile(
  loss=[custom_loss_builder(model.layers[-1]), dummy_loss],
  loss_weights=[1.0, 0.0], optimizer= 'adam')

# run; note that y_train is repeated as the target for dummy_output. It is never used, so a dummy_y_train array would work just as well
history = model.fit([X_train, dummy_x_train], [y_train, y_train],
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                   callbacks=[EarlyStopping(monitor='val_loss', patience=5)])
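The expectation in the loss comment ("var2 should get to zero, var1 should get to minus infinity?") follows from the partial derivatives of this toy loss. A short worked derivation, not part of the original post, writing v_1 for var1, v_2 for var2, \hat{y}_i for the predictions and n for the batch size:

```latex
L(v_1, v_2) = v_1 \cdot \underbrace{\frac{1}{n}\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}_{\mathrm{MSE}} + v_2^2,
\qquad
\frac{\partial L}{\partial v_1} = \mathrm{MSE} \ge 0,
\qquad
\frac{\partial L}{\partial v_2} = 2 v_2
```

So every gradient step lowers var1 as long as the MSE is positive, and pulls var2 toward 0. Once var1 is negative, minimizing L with respect to the network weights means maximizing the MSE, so nothing bounds var1 from below.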

It seems to work: whatever the starting values of var1 and var2 (i.e. the initialization of the dummy_output layer), they tend toward minus infinity and 0, respectively:

(This plot comes from training the model iteratively, one epoch at a time, and saving the two weights after each epoch, as below.)

var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
        print('step %d' % i)
    model.fit([X_train, dummy_x_train], [y_train, y_train],
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_weights()[0]
    var1_list.append(var1.item())
    var2_list.append(var2.item())

plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()

[Plot: var1 and var2 over 100 training iterations]


Solution

  • Answering my own question: after days of struggling I got it to work without a dummy input. I think this is much better and should be the "canonical" way until Keras/TF simplify the process. This is how the Keras/TF docs do it here.

    The key to using a loss function with external trainable variables is to work with a custom loss/output Layer that calls self.add_loss(...) in its call() implementation, like so:

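    from tensorflow.keras.layers import Layer
    import tensorflow.keras.backend as K

    class MyLoss(Layer):
        def __init__(self, var1, var2):
            super().__init__()
            # variables assigned as layer attributes are tracked as trainable weights
            self.var1 = K.variable(var1) # or tf.Variable(var1) etc.
            self.var2 = K.variable(var2)
        
        def get_vars(self):
            return self.var1, self.var2
        
        def custom_loss(self, y_true, y_pred):
            return self.var1 * K.mean(K.square(y_true - y_pred)) + self.var2 ** 2
        
        def call(self, y_true, y_pred):
            # register the loss; its gradient reaches var1, var2 and the network weights
            self.add_loss(self.custom_loss(y_true, y_pred))
            # pass the predictions through unchanged
            return y_pred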
    

    Notice that the MyLoss layer needs two inputs: the actual y_true and the prediction y_pred produced up to that point:

    inputs = Input(shape=(X_train.shape[1],))
    y_input = Input(shape=(1,))
    hidden1 = Dense(10)(inputs)
    output = Dense(1)(hidden1)
    my_loss = MyLoss(0.5, 0.5)(y_input, output) # here can also initialize those var1, var2
    model = Model(inputs=[inputs, y_input], outputs=my_loss)
    
    model.compile(optimizer= 'adam')
    

    Finally, as the TF docs mention, in this case you do not have to specify a loss in compile() or a target y in fit():

    history = model.fit([X_train, y_train], None,
                        batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                        callbacks=[EarlyStopping(monitor='val_loss', patience=5)])
    

    Again, notice that y_train comes into fit() as one of the inputs.
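Because y_train is now a model input, the trained model also expects a y input at prediction time. A minimal sketch of one way around that, assuming TF 2.x `tf.keras`: build a second `Model` from the same tensors that stops at `output`, so it shares the trained `Dense` layers and needs no y. The sketch also creates var1/var2 with the standard `Layer.add_weight` method instead of `K.variable`. The class name `MyLossWeights` and the names `init1`, `init2`, `loss_layer`, `predict_model` are illustrative, not from the original answer.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Layer


class MyLossWeights(Layer):
    """Illustrative variant of MyLoss: var1/var2 are created with add_weight."""

    def __init__(self, init1=0.5, init2=0.5, **kwargs):
        super().__init__(**kwargs)
        # Scalar trainable weights, tracked by Keras like any kernel or bias.
        self.var1 = self.add_weight(
            name="var1", shape=(),
            initializer=tf.keras.initializers.Constant(init1), trainable=True)
        self.var2 = self.add_weight(
            name="var2", shape=(),
            initializer=tf.keras.initializers.Constant(init2), trainable=True)

    def call(self, y_true, y_pred):
        # Same toy loss as MyLoss.custom_loss, written with tf ops.
        mse = tf.reduce_mean(tf.square(y_true - y_pred))
        self.add_loss(self.var1 * mse + tf.square(self.var2))
        return y_pred


# Small synthetic regression problem in the spirit of the question.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 10)).astype("float32")
y = (X @ np.arange(10)).reshape(-1, 1).astype("float32")

inputs = Input(shape=(10,))
y_input = Input(shape=(1,))
output = Dense(1)(Dense(10)(inputs))
loss_layer = MyLossWeights(0.5, 0.5)
model = Model(inputs=[inputs, y_input], outputs=loss_layer(y_input, output))
model.compile(optimizer="adam")

# y is passed as an input, so the target argument of fit() is None.
model.fit([X, y], None, batch_size=32, epochs=1, verbose=0)

# Prediction-only model: same Dense layers, no y input required.
predict_model = Model(inputs=inputs, outputs=output)
preds = predict_model.predict(X, verbose=0)
print(float(loss_layer.var1.numpy()), float(loss_layer.var2.numpy()), preds.shape)
```

Since `predict_model` reuses the layer objects, it always reflects the current weights of `model`; there is nothing to copy after training.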

    Now it works:

    var1_list = []
    var2_list = []
    for i in range(100):
        if i % 10 == 0:
            print('step %d' % i)
        model.fit([X_train, y_train], None,
                  batch_size=32, epochs=1, validation_split=0.1, verbose=0)
        var1, var2 = model.layers[-1].get_vars()
        var1_list.append(var1.numpy())
        var2_list.append(var2.numpy())
    
    plt.plot(var1_list, label='var1')
    plt.plot(var2_list, 'r', label='var2')
    plt.legend()
    plt.show()
    

    [Plot: var1 and var2 over 100 training iterations, MyLoss version]

    (I should also mention that this specific pattern for var1 and var2 depends strongly on their initial values: if var1 starts above 1, it does not in fact decrease toward minus infinity.)
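Not part of the original answer: in TF 2.x eager mode, another way to get extra trainable variables into a loss is a custom training loop with `tf.GradientTape`. var1 and var2 are then plain `tf.Variable`s, and the optimizer updates them together with the network weights, with no dummy input and no loss layer. A minimal sketch, assuming the same toy loss and synthetic data; `net`, `custom_loss` and `train_step` are illustrative names, and the loop runs eagerly for simplicity.

```python
import numpy as np
import tensorflow as tf

# Synthetic data in the spirit of the question: y = X @ beta.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10)).astype("float32")
y = (X @ np.arange(10)).reshape(-1, 1).astype("float32")

# Any Keras network works here; this mirrors Dense(10) -> Dense(1).
net = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Dense(1),
])

# Extra trainable scalars that appear only in the loss.
var1 = tf.Variable(0.5, dtype=tf.float32, name="var1")
var2 = tf.Variable(0.5, dtype=tf.float32, name="var2")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


def custom_loss(y_true, y_pred):
    # Same toy loss as in the question.
    return var1 * tf.reduce_mean(tf.square(y_true - y_pred)) + tf.square(var2)


def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        loss = custom_loss(y_batch, net(x_batch, training=True))
    # Differentiate w.r.t. the network weights AND the extra variables.
    train_vars = list(net.trainable_variables) + [var1, var2]
    grads = tape.gradient(loss, train_vars)
    optimizer.apply_gradients(zip(grads, train_vars))
    return loss


dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(32)
for epoch in range(3):
    for x_batch, y_batch in dataset:
        train_step(x_batch, y_batch)

print("var1:", float(var1.numpy()), "var2:", float(var2.numpy()))
```

The trade-off is that fit() conveniences such as validation_split and callbacks like EarlyStopping are not available and would have to be reimplemented inside the loop.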