
TensorFlow: Data-dependent loss function


I am trying to implement a loss function whose value depends on the (unaugmented) input data. I found an example that does this with the add_loss() method of a tf.keras.models.Model, but I am struggling to implement it.

I have a tf.data.Dataset containing my data, my labels, and a data-dependent variable computed for every sample before augmentation (call it z). z is what I want to pass to my custom loss function.

Where I get stuck is passing the predictions, the labels, and z to my loss function through model.add_loss().

Given a simple model like:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.losses import Loss
import numpy as np
 
data = Input(shape=(1,), dtype=tf.float32)
label = Input(shape=(3,), dtype=tf.float32)
z = Input(shape=(1,), dtype=tf.float32)

out = Dense(3)(data)

m = Model(inputs=[data, label, z], outputs=out)

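def my_loss(y_true, y_pred, z):
    # Dense(3) has no activation, so its outputs are logits.
    cce = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    cce_loss = cce(y_true, y_pred)  # shape (batch,)
    # z has shape (batch, 1); squeeze it so the product stays (batch,).
    return tf.reduce_mean(tf.multiply(cce_loss, tf.squeeze(z, axis=-1)))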
  
m.add_loss(my_loss(label, out, z))
 
m.compile(loss=None, optimizer='adam')

dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.1, 0.2, 0.3]))

m.fit(dataset, epochs=10)

Running this raises:

ValueError: Layer "model_17" expects 3 input(s), but it received 1 input tensors.
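The error can be traced to how Keras unpacks each element of a tf.data.Dataset passed to fit(): a tuple of length 3 is read as (x, y, sample_weight), so only `data` reaches the model as x. A minimal sketch, assuming TF 2.x, that applies the public helper tf.keras.utils.unpack_x_y_sample_weight (the same convention the default train_step uses) to one batch of the question's dataset:

```python
import tensorflow as tf

# The dataset from the question: every element is a flat 3-tuple.
dataset = tf.data.Dataset.from_tensor_slices(
    ([1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.1, 0.2, 0.3])
).batch(3)

batch = next(iter(dataset))

# Split the batch the way Keras does during fit().
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)

print(x.shape)              # only `data`: one tensor instead of three inputs
print(y.shape)              # the labels land in the `y` slot
print(sample_weight.shape)  # z lands in the `sample_weight` slot
```

Because z ends up in the sample_weight slot, a loss that is nothing more than per-sample-weighted cross-entropy could in principle be trained on the original 3-tuple dataset with an ordinary compiled loss and a single-input model. Nesting the inputs is the more general route when z enters the loss in other ways.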

Is there a way to feed a list of inputs [data, label, z] from a tf.data.Dataset? Or, if I pass the dataset as a single input, how do I access the three values inside the model?


Solution

  • One way to accomplish this is to use tf.data.Dataset.zip to build a dataset whose elements have the structure ((data, label, z), target), so that the three inputs are passed together as x. Here is the modified code:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model
    from tensorflow.keras.losses import Loss
    import numpy as np
     
    data = Input(shape=(1,), dtype=tf.float32)
    label = Input(shape=(3,), dtype=tf.float32)
    z = Input(shape=(1,), dtype=tf.float32)
    
    out = Dense(3)(data)
    
    m = Model(inputs=[data, label, z], outputs=out)
    
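    def my_loss(y_true, y_pred, z):
        # Dense(3) has no activation, so its outputs are logits.
        cce = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        cce_loss = cce(y_true, y_pred)  # shape (batch,)
        # z has shape (batch, 1); squeeze it so the product stays (batch,)
        # instead of broadcasting to (batch, batch).
        return tf.reduce_mean(tf.multiply(cce_loss, tf.squeeze(z, axis=-1)))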
      
    m.add_loss(my_loss(label, out, z))
    m.compile(loss=None, optimizer='adam')
    
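    # Original code: a flat 3-tuple, which Keras unpacks as (x, y, sample_weight),
    # so the model receives only `data` as input.
    # dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.1, 0.2, 0.3]))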
    
    #----------------------------------------------
    # modified code
    #----------------------------------------------
    inp_data = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    inp_label = tf.data.Dataset.from_tensor_slices([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    inp_z = tf.data.Dataset.from_tensor_slices([0.1, 0.2, 0.3])
    
    # Dummy `y` for the dataset. It is ignored because the model is compiled
    # with `loss=None`; the real labels (`inp_label`) still reach `my_loss`
    # as a model input through `add_loss`.
    target = tf.data.Dataset.from_tensor_slices(tf.zeros(3))
    
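    # Nest the three inputs in one tuple so Keras treats them together as `x`:
    # each element then has the structure ((data, label, z), target).
    input_zip = tf.data.Dataset.zip((inp_data, inp_label, inp_z))
    dataset = tf.data.Dataset.zip((input_zip, target))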
    
    dataset = dataset.batch(3)
    #----------------------------------------------
    
    m.fit(dataset, epochs=3)
    

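The dummy target is not strictly required when the model is compiled with loss=None. A minimal sketch, assuming TF 2.x with Keras 2 (which the symbolic add_loss approach above relies on): mapping each zipped element to a 1-tuple makes Keras unpack it as (x, None, None), so the whole (data, label, z) triple becomes x. The variable names mirror the solution; the check at the end uses tf.keras.utils.unpack_x_y_sample_weight only to show how the batch would be split.

```python
import tensorflow as tf

inp_data = tf.data.Dataset.from_tensor_slices([1, 2, 3])
inp_label = tf.data.Dataset.from_tensor_slices([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
inp_z = tf.data.Dataset.from_tensor_slices([0.1, 0.2, 0.3])

# Zip the inputs, then wrap each element in a 1-tuple: ((data, label, z),).
# Keras reads a 1-tuple as (x,), leaving y and sample_weight as None.
dataset = tf.data.Dataset.zip((inp_data, inp_label, inp_z))
dataset = dataset.map(lambda data, label, z: ((data, label, z),)).batch(3)

# Inspect how one batch would be unpacked during fit().
batch = next(iter(dataset))
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
print(len(x), y, sample_weight)
```

This dataset should be usable in m.fit() in place of the zipped one; with loss=None, only the add_loss term is optimized.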


    Note

    Examine both the code and the model architecture carefully to make sure the model does what you intend. One helpful tool is the plot_model function from tf.keras.utils (it needs pydot and Graphviz installed):

    tf.keras.utils.plot_model(
        m,
        to_file='model.png',
        show_shapes=True,
        show_layer_names=True,
    )
    

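As an example of the kind of check the Note recommends, the shapes inside my_loss are worth verifying. With reduction NONE, the cross-entropy is one value per sample, shape (batch,), whereas z from Input(shape=(1,)) has shape (batch, 1). Multiplying them without reshaping broadcasts to (batch, batch), so every sample's loss is multiplied by every z. The sketch below reproduces this with hypothetical per-sample loss values and shows that tf.squeeze(z, axis=-1) keeps the product per-sample.

```python
import tensorflow as tf

# Hypothetical per-sample losses, shape (batch,) as returned with Reduction.NONE.
cce_loss = tf.constant([0.5, 1.0, 1.5])
# z as it arrives from Input(shape=(1,)): shape (batch, 1).
z = tf.constant([[0.1], [0.2], [0.3]])

# (3,) * (3, 1) broadcasts to (3, 3): every loss times every z.
wrong = tf.multiply(cce_loss, z)
# Squeezing z to (3,) gives one weighted loss per sample.
right = tf.multiply(cce_loss, tf.squeeze(z, axis=-1))

print(wrong.shape, right.shape)
print(float(tf.reduce_mean(wrong)), float(tf.reduce_mean(right)))
```

With broadcasting, the mean collapses to mean(loss) * mean(z), which discards the per-sample pairing that z is meant to express.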