Tags: tensorflow, keras, autoencoder, loss, custom-training

How to create a joint loss with paired Dataset samples in the TensorFlow Keras API?


I'm trying to train an autoencoder with constraints that force one or more of the hidden (encoded) neurons to take an interpretable value. My training approach uses paired images (though after training the model should operate on a single image) and a joint loss function that includes (1) the reconstruction loss for each of the images and (2) a comparison between values of the hidden/encoded vector from each of the two images.

I've created an analogous simple toy problem and model to make this clearer. In the toy problem, the autoencoder is given a vector of length 3 as input. The encoding uses one dense layer to compute the mean (a scalar) and another dense layer to compute some other representation of the vector (given my construction, it will likely just learn an identity matrix, i.e., copy the input vector). See the figure below. The lowest node of the hidden layer is intended to compute the mean of the input vector. The rest of the hidden nodes are unconstrained aside from having to accommodate a reconstruction that matches the input.

[Figure: Toy model]
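
To make the role of the lowest hidden node concrete, here is a minimal plain-Python sketch (not part of the model code in this post) of a single linear neuron with all weights equal to 1/n and zero bias, which is one set of weights that makes it output the mean. The names `linear_neuron`, `x1`, `x2`, and `d` are hypothetical and chosen only for illustration. It also shows that shifting every element by a constant `d` shifts the mean by `d`, which is the relationship the paired training relies on.

```python
def linear_neuron(x, weights, bias=0.0):
    """One dense unit with no activation: weighted sum of inputs plus bias."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias

n = 3
mean_weights = [1.0 / n] * n        # weights that make the neuron compute the mean

x1 = [3.0, 6.0, 9.0]                # illustrative input vector
d = 2.5                             # illustrative shift, playing the role of target_mean_diff
x2 = [xi + d for xi in x1]          # second vector of the pair

mean1 = linear_neuron(x1, mean_weights)
mean2 = linear_neuron(x2, mean_weights)
print(mean1, mean2, mean2 - mean1)  # approximately 6.0, 8.5 and 2.5 (floating point)
```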

The figure below shows how I wish to train the model, using paired images. "MSE" is mean-squared error, although the exact choice of function is not important for the question I'm asking here. The loss function is the sum of the reconstruction loss and the mean-estimation loss.

[Figure: Toy model training]
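
Written out, the joint loss described above is roughly L = MSE(x1, x̂1)/2 + MSE(x2, x̂2)/2 + MSE(Δμ, μ̂2 − μ̂1). The 1/2 weights follow the `loss_total` function in the code below, and the symbols are shorthand introduced here rather than notation from the figure. The plain-Python sketch below works through that sum for one made-up batch of two pairs; the helper `mse` and all numbers are hypothetical.

```python
def mse(y_true, y_pred):
    """Mean squared error over every element of a batch (a list of rows)."""
    squared = [(t - p) ** 2
               for row_t, row_p in zip(y_true, y_pred)
               for t, p in zip(row_t, row_p)]
    return sum(squared) / len(squared)

# Targets: the second vector is the first plus a per-pair shift (1.0 and 2.0).
y_true = {"target_vec1": [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]],
          "target_vec2": [[1.0, 2.0, 3.0], [3.0, 3.0, 3.0]],
          "target_mean_diff": [[1.0], [2.0]]}

# Made-up model outputs, each with a single deliberate error.
y_pred = {"target_vec1": [[0.0, 1.0, 2.0], [1.0, 1.0, 4.0]],
          "target_vec2": [[1.0, 2.0, 3.0], [3.0, 3.0, 0.0]],
          "target_mean_diff": [[1.0], [4.0]]}

# Joint loss: half of each reconstruction MSE plus the mean-difference MSE.
joint = (mse(y_true["target_vec1"], y_pred["target_vec1"]) / 2
         + mse(y_true["target_vec2"], y_pred["target_vec2"]) / 2
         + mse(y_true["target_mean_diff"], y_pred["target_mean_diff"]))

# The same total, viewed as a weighted sum of three per-output losses.
weights = {"target_vec1": 0.5, "target_vec2": 0.5, "target_mean_diff": 1.0}
weighted = sum(weights[k] * mse(y_true[k], y_pred[k]) for k in weights)

print(joint, weighted)  # 3.5 3.5
```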

I've tried creating (1) a tf.data.Dataset to generate paired vectors, (2) a Keras model, and (3) a custom loss function. However, I'm failing to understand how to do this correctly for this particular situation.

I can't get Model.fit() to run correctly and associate the model outputs with the Dataset targets as intended. See the code and errors below. Can anyone help? I've done many Google and Stack Overflow searches and still don't understand how to implement this.

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

DTYPE = tf.dtypes.float32
N_VEC = 3

def my_generator(n):
    while True:
        # Create two vectors of length n that are identical except for their means.
        # An internal layer (single neuron) of the model should predict the
        # mean of the input vector. To train it to do so, with paired
        # vector inputs, use a loss function that penalizes incorrect
        # predictions of the difference of the means of two input vectors.
        input_vec1 = tf.random.normal((n,), dtype=DTYPE)
        target_mean_diff = tf.random.normal((1,), dtype=DTYPE)
        input_vec2 = input_vec1 + target_mean_diff
        
        # Model is a constrained autoencoder. Output targets are
        # identical to the input vectors. Including them as explicit
        # targets in this generator, for generalization.
        target_vec1 = tf.identity(input_vec1)
        target_vec2 = tf.identity(input_vec2)
        
        yield ({'input_vec1':input_vec1,
                'input_vec2':input_vec2},
               {'target_vec1':target_vec1,
                'target_vec2':target_vec2,
                'target_mean_diff':target_mean_diff})

def my_dataset(n, batch_size=4):
    ds = tf.data.Dataset.from_generator(my_generator,
                                        output_signature=({'input_vec1':tf.TensorSpec(shape=(n,), dtype=DTYPE),
                                                           'input_vec2':tf.TensorSpec(shape=(n,), dtype=DTYPE)},
                                                          {'target_vec1':tf.TensorSpec(shape=(n,), dtype=DTYPE),
                                                           'target_vec2':tf.TensorSpec(shape=(n,), dtype=DTYPE),
                                                           'target_mean_diff':tf.TensorSpec(shape=(1,), dtype=DTYPE)}),
                                        args=(n,))
    ds = ds.batch(batch_size)    
    return ds


## Do a brief test using the Dataset
ds = my_dataset(N_VEC, batch_size=4)
ds_iter = iter(ds)
dict_inputs, dict_targets = next(ds_iter)
print(dict_inputs)
print(dict_targets)


## Define the Model
layer_encode_vec = tf.keras.layers.Dense(N_VEC, activation=None, name='encode_vec')
layer_decode_vec = tf.keras.layers.Dense(N_VEC, activation=None, name='decode_vec')
layer_encode_mean = tf.keras.layers.Dense(1, activation=None, name='encode_mean')
layer_decode_mean = tf.keras.layers.Dense(N_VEC, activation=None, name='decode_mean')

input1 = tf.keras.Input(shape=(N_VEC,), name='input_vec1')
input2 = tf.keras.Input(shape=(N_VEC,), name='input_vec2')
vec_encoded1 = layer_encode_vec(input1)
vec_encoded2 = layer_encode_vec(input2)
mean_encoded1 = layer_encode_mean(input1)
mean_encoded2 = layer_encode_mean(input2)
mean_diff = mean_encoded2 - mean_encoded1
pred_vec1 = layer_decode_vec(vec_encoded1) + layer_decode_mean(mean_encoded1)
pred_vec2 = layer_decode_vec(vec_encoded2) + layer_decode_mean(mean_encoded2)

model = tf.keras.Model(inputs=[input1, input2], outputs=[pred_vec1, pred_vec2, mean_diff])

model.summary()


## Define the joint loss function
def loss_total(y_true, y_pred):
    loss_reconstruct = tf.reduce_mean(tf.keras.losses.MSE(y_true[0], y_pred[0]))/2 + \
                       tf.reduce_mean(tf.keras.losses.MSE(y_true[1], y_pred[1]))/2
    loss_mean = tf.reduce_mean(tf.keras.losses.MSE(y_true[2], y_pred[2]))
    return loss_reconstruct + loss_mean


## Compile model
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer, loss=loss_total)


## Train model
history = model.fit(x=ds, epochs=10, steps_per_epoch=10)

Output: Example batch from the Dataset:

{'input_vec1': <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[-0.53022575, -0.02389329,  0.32843253],
       [-0.61793506, -0.8276422 , -1.3469328 ],
       [-0.5401968 ,  0.3141346 , -1.3638284 ],
       [-1.2189807 ,  0.23848908,  0.75108534]], dtype=float32)>, 'input_vec2': <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[-0.23415083,  0.27218163,  0.6245074 ],
       [-0.57636774, -0.7860749 , -1.3053654 ],
       [ 0.65463066,  1.508962  , -0.16900098],
       [-0.49326736,  0.9642024 ,  1.4767987 ]], dtype=float32)>}
{'target_vec1': <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[-0.53022575, -0.02389329,  0.32843253],
       [-0.61793506, -0.8276422 , -1.3469328 ],
       [-0.5401968 ,  0.3141346 , -1.3638284 ],
       [-1.2189807 ,  0.23848908,  0.75108534]], dtype=float32)>, 'target_vec2': <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[-0.23415083,  0.27218163,  0.6245074 ],
       [-0.57636774, -0.7860749 , -1.3053654 ],
       [ 0.65463066,  1.508962  , -0.16900098],
       [-0.49326736,  0.9642024 ,  1.4767987 ]], dtype=float32)>, 'target_mean_diff': <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[0.29607493],
       [0.04156734],
       [1.1948274 ],
       [0.7257133 ]], dtype=float32)>}

Output: The model summary:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_vec1 (InputLayer)         [(None, 3)]          0                                            
__________________________________________________________________________________________________
input_vec2 (InputLayer)         [(None, 3)]          0                                            
__________________________________________________________________________________________________
encode_vec (Dense)              (None, 3)            12          input_vec1[0][0]                 
                                                                 input_vec2[0][0]                 
__________________________________________________________________________________________________
encode_mean (Dense)             (None, 1)            4           input_vec1[0][0]                 
                                                                 input_vec2[0][0]                 
__________________________________________________________________________________________________
decode_vec (Dense)              (None, 3)            12          encode_vec[0][0]                 
                                                                 encode_vec[1][0]                 
__________________________________________________________________________________________________
decode_mean (Dense)             (None, 3)            6           encode_mean[0][0]                
                                                                 encode_mean[1][0]                
__________________________________________________________________________________________________
tf.__operators__.add (TFOpLambd (None, 3)            0           decode_vec[0][0]                 
                                                                 decode_mean[0][0]                
__________________________________________________________________________________________________
tf.__operators__.add_1 (TFOpLam (None, 3)            0           decode_vec[1][0]                 
                                                                 decode_mean[1][0]                
__________________________________________________________________________________________________
tf.math.subtract (TFOpLambda)   (None, 1)            0           encode_mean[1][0]                
                                                                 encode_mean[0][0]                
==================================================================================================
Total params: 34
Trainable params: 34
Non-trainable params: 0
__________________________________________________________________________________________________

Output: The error message when calling model.fit():

Epoch 1/10
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)

...

ValueError: Found unexpected keys that do not correspond to any
Model output: dict_keys(['target_vec1', 'target_vec2', 'target_mean_diff']).
Expected: ['tf.__operators__.add', 'tf.__operators__.add_1', 'tf.math.subtract']

Solution

  • You can pass a dict to Model for both inputs and outputs like so:

    model = tf.keras.Model(
        inputs={"input_vec1": input1, "input_vec2": input2},
        outputs={
            "target_vec1": pred_vec1,
            "target_vec2": pred_vec2,
            "target_mean_diff": mean_diff,
        },
    )
    

    which avoids having to name the output layers.
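
For comparison, the alternative that the dict approach avoids is to give each output layer the name of its target key. The sketch below assumes the `+` and `-` operations are replaced with `tf.keras.layers.Add` and `tf.keras.layers.Subtract` so that each output can carry a `name`. The variable `named_model` is hypothetical, and the snippet only inspects the resulting output names rather than running training.

```python
import tensorflow as tf

N_VEC = 3

# Same shared layers as in the question.
encode_vec = tf.keras.layers.Dense(N_VEC, name="encode_vec")
decode_vec = tf.keras.layers.Dense(N_VEC, name="decode_vec")
encode_mean = tf.keras.layers.Dense(1, name="encode_mean")
decode_mean = tf.keras.layers.Dense(N_VEC, name="decode_mean")

in1 = tf.keras.Input(shape=(N_VEC,), name="input_vec1")
in2 = tf.keras.Input(shape=(N_VEC,), name="input_vec2")
mean1, mean2 = encode_mean(in1), encode_mean(in2)

# Named layers replace the anonymous `+` and `-` ops, so each output
# is named after the target key it should be matched with.
out1 = tf.keras.layers.Add(name="target_vec1")(
    [decode_vec(encode_vec(in1)), decode_mean(mean1)])
out2 = tf.keras.layers.Add(name="target_vec2")(
    [decode_vec(encode_vec(in2)), decode_mean(mean2)])
diff = tf.keras.layers.Subtract(name="target_mean_diff")([mean2, mean1])

named_model = tf.keras.Model(inputs=[in1, in2], outputs=[out1, out2, diff])
print(named_model.output_names)
```

With output names that match the target keys, Keras should no longer report the "unexpected keys" error from the question, though the dict form above is shorter.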

    As for the loss, Keras currently applies loss_total to each of the 3 outputs individually and sums the results to get the final loss, which is not what you want. So you can either specify a separate loss for each output:

    model.compile(
        optimizer=optimizer,
        loss={"target_vec1": "mse", "target_vec2": "mse", "target_mean_diff": "mse"},
        loss_weights={"target_vec1": 0.5, "target_vec2": 0.5, "target_mean_diff": 1},
    )
    

    or you can train the model manually, using a modified loss function that takes dicts as input. Something like:

    def loss_total(y_true, y_pred):
        loss_reconstruct = (
            tf.reduce_mean(tf.keras.losses.MSE(y_true["target_vec1"], y_pred["target_vec1"])) / 2
            + tf.reduce_mean(tf.keras.losses.MSE(y_true["target_vec2"], y_pred["target_vec2"])) / 2
        )
        loss_mean = tf.reduce_mean(tf.keras.losses.MSE(y_true["target_mean_diff"], y_pred["target_mean_diff"]))
        return loss_reconstruct + loss_mean
    
    for epoch in range(10):
        for batch, (x, y) in zip(range(10), ds):
            with tf.GradientTape() as tape:
                outputs = model(x, training=True)
                loss = loss_total(y, outputs)
    
            trainable_vars = model.trainable_variables
            gradients = tape.gradient(loss, trainable_vars)
            optimizer.apply_gradients(zip(gradients, trainable_vars))
            print(f"Batch: {batch}, loss: {loss.numpy()}")
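
If you would rather keep `Model.fit()` than write the loop by hand, one further option (not part of the answer above) is to override `train_step` in a `tf.keras.Model` subclass, so that the joint loss sees all three outputs at once. This is a minimal sketch under that assumption: the class name `JointLossModel` and the small fixed dataset are hypothetical, and the details of `train_step` can differ between Keras versions.

```python
import tensorflow as tf

N_VEC = 3

class JointLossModel(tf.keras.Model):
    """Hypothetical subclass whose train_step sees all three outputs at once."""

    def train_step(self, data):
        x, y = data  # dict of inputs, dict of targets
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # dict keyed like the targets

            def mse(key):
                return tf.reduce_mean(tf.square(y[key] - y_pred[key]))

            # Joint loss: half of each reconstruction MSE plus the mean-difference MSE.
            loss = (0.5 * mse("target_vec1") + 0.5 * mse("target_vec2")
                    + mse("target_mean_diff"))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

# Same toy architecture as in the question, with dict inputs and outputs.
encode_vec = tf.keras.layers.Dense(N_VEC, name="encode_vec")
decode_vec = tf.keras.layers.Dense(N_VEC, name="decode_vec")
encode_mean = tf.keras.layers.Dense(1, name="encode_mean")
decode_mean = tf.keras.layers.Dense(N_VEC, name="decode_mean")

input1 = tf.keras.Input(shape=(N_VEC,), name="input_vec1")
input2 = tf.keras.Input(shape=(N_VEC,), name="input_vec2")
mean1, mean2 = encode_mean(input1), encode_mean(input2)

joint_model = JointLossModel(
    inputs={"input_vec1": input1, "input_vec2": input2},
    outputs={
        "target_vec1": decode_vec(encode_vec(input1)) + decode_mean(mean1),
        "target_vec2": decode_vec(encode_vec(input2)) + decode_mean(mean2),
        "target_mean_diff": mean2 - mean1,
    },
)

# Small fixed dataset of paired vectors (illustrative only).
vec1 = tf.random.normal((16, N_VEC))
shift = tf.random.normal((16, 1))
vec2 = vec1 + shift
pairs = tf.data.Dataset.from_tensor_slices(
    ({"input_vec1": vec1, "input_vec2": vec2},
     {"target_vec1": vec1, "target_vec2": vec2, "target_mean_diff": shift})
).batch(4)

# No `loss` argument: the loss is computed inside train_step.
joint_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
history = joint_model.fit(pairs, epochs=2, verbose=0)
print(history.history["loss"])
```

The trade-off is that `compile()` no longer knows about the loss, so anything you want reported during training has to be returned from `train_step` yourself.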