Tags: python, tensorflow, machine-learning, keras, functional-api

Custom loss function with multiple inputs for validation


I am creating a custom loss function following the instructions found here. When I pass validation_data to fit, I get a ValueError. When I set validation_data=None, the error disappears. I found a similar question on Stack Overflow, but I think my issue is different because I am using a custom loss function.

Here is my code:

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras import regularizers

def loss_fcn(y_true, y_pred, w):
    loss = K.mean(K.square((y_true-y_pred)*w))
    return loss

# TensorFlow sets the default batch_size to 32, so the number of samples has to be a multiple of 32 when it is greater than 32.
data_x = np.random.rand(32, 51)
data_w = np.random.rand(32, 5)
data_y = np.random.rand(32, 5)

val_x = np.random.rand(4, 51)
val_w = np.random.rand(4, 5)
val_y = np.random.rand(4, 5)

input_x = Input(shape=(51,), name="input")
y_true = Input(shape=(5,), name="true_y")
w = Input(shape=(5,), name="weights")

out = Dense(128, kernel_regularizer=regularizers.l2(0.001), name="HL1")(input_x)
y = Dense(5, name="HL2", activation="tanh")(out)

model = Model(inputs=[input_x, y_true, w], outputs=y)
model.add_loss(loss_fcn(y_true, y, w))

model.compile()
model.fit((data_x, data_y, data_w), validation_data=(val_x, val_y, val_w))

The error message:

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays: [array([[0.74785946, 0.63599707, 0.45929641, 0.98855504, 0.84815295,
        0.28217452, 0.93502174, 0.23942027, 0.11885888, 0.32092279,
        0.47407394, 0.19737623, 0.85962504, 0.35906666, 0.22262...


Solution

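  • Keras reads a tuple passed as validation_data as (x, y, sample_weight), so only val_x reaches the model as input. Pass the training and validation data as lists instead of tuples: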

    model.fit([data_x, data_y, data_w], validation_data=[val_x, val_y, val_w])
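To see why the container type matters, here is a minimal pure-Python sketch of the unpacking rule that tf.keras appears to apply to validation_data in TF 2.2 and later. The function name unpack_like_keras is hypothetical and only imitates that behaviour. It is not the library's code, and other versions (older TF 2.x releases, Keras 3) may treat lists differently. Strings stand in for the arrays.

```python
def unpack_like_keras(data):
    """Hypothetical stand-in for how tf.keras (assumed TF 2.2+) splits validation data."""
    # Anything that is not a tuple is treated as the inputs x as a whole.
    if not isinstance(data, tuple):
        return data, None, None
    # Tuples are read positionally as (x, y, sample_weight).
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError("expected a tuple of length 1, 2, or 3")


# Strings stand in for the NumPy arrays from the question.
val_x, val_y, val_w = "val_x", "val_y", "val_w"

x_from_tuple, y_from_tuple, sw_from_tuple = unpack_like_keras((val_x, val_y, val_w))
x_from_list, y_from_list, sw_from_list = unpack_like_keras([val_x, val_y, val_w])

print(x_from_tuple)  # val_x  (one input array, as in the error message)
print(x_from_list)   # ['val_x', 'val_y', 'val_w']  (all three inputs)
```

The training call in the question did not fail because its tuple is passed as the x argument itself and is never unpacked this way. If the installed version behaves as sketched, validation_data=([val_x, val_y, val_w], None) should state the same intent more explicitly: all three arrays are inputs and there is no separate target.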