Tags: python, tensorflow, keras

Extracting an item from a TensorFlow InputLayer


I am building a neural network with an input layer, a dense layer and an output layer. I'd like to use a number from the input vector in my custom loss function.

Let's say I have the following loss function:

import tensorflow as tf

def custom_loss_wrapper(input_layer):
    def custom_loss(y_true, y_pred):
        return tf.keras.metrics.mean_squared_error(y_true, y_pred)  # + 10th item of the input vector (input_layer.output[9]?)
    return custom_loss

The sequential model looks like this:

model = tf.keras.models.Sequential(
    [
        input_layer := tf.keras.layers.InputLayer(shape=(x.shape[1],)),
        first_layer := tf.keras.layers.Dense(128, activation="relu"),
        output_layer := tf.keras.layers.Dense(y.shape[1], activation="sigmoid"),
    ]
)

model.compile(
    optimizer="adam", loss=custom_loss_wrapper(input_layer), metrics=["mae"], run_eagerly=True, 
)

model.fit(x=train_input, y=train_output, epochs=5, validation_split=0.5)

Unfortunately, I do not know how to access, for example, the 10th item of the input vector. I tried input_layer.output[9] and many other approaches, but none of them worked. Thank you for your help!

Versions: Keras==3.0.0, TensorFlow==2.15, Python 3.11


Solution

  • I think the best way to do this would be to use a custom subclassed model and a custom training loop. In the model's call method, you return both the prediction and the values you want to add (here, taken from the output of the first layer), and then write a custom loss that accepts three arguments.

    The following would do it, once you substitute your own input dataset, your specific model, and the column you want to add.

    import tensorflow as tf
    
    x = tf.data.Dataset.from_tensor_slices(tf.random.uniform((32, 4)))
    y = tf.data.Dataset.from_tensor_slices(tf.random.uniform((32, 1)))
    
    ds = tf.data.Dataset.zip((x, y)).batch(8)
    
    
    class MyModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.d0 = tf.keras.layers.Dense(16, activation='relu')
            self.d1 = tf.keras.layers.Dense(32, activation='relu')
            self.d2 = tf.keras.layers.Dense(1, activation='linear')
    
        def call(self, inputs, training=None, **kwargs):
            out_layer_1 = self.d0(inputs)
            out_layer_2 = self.d1(out_layer_1)
            out_layer_3 = self.d2(out_layer_2)
            return out_layer_3, out_layer_1[:, 3]
            # select the desired column above, here it's 3;
            # shape (batch,), so it adds elementwise to the per-sample MSE
    
    
    model = MyModel()
    
    def loss_function(y_true, y_pred, to_add):
        # per-sample MSE plus the extra term returned by the model
        return tf.keras.metrics.mean_squared_error(y_true, y_pred) + to_add
    
    
    # sanity check
    # batch_x, batch_y = next(iter(ds))
    # out_y, out_to_add = model(batch_x)
    # loss_function(batch_y, out_y, out_to_add)
    
    optimizer = tf.keras.optimizers.Adam()
    
    loss = tf.keras.metrics.Mean(name='loss')
    
    
    @tf.function
    def train_step(inputs, targets):
        with tf.GradientTape() as tape:
            predictions, layer_outputs = model(inputs)
            run_loss = loss_function(targets, predictions, layer_outputs)
        gradients = tape.gradient(run_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        loss(run_loss)
    
    
    for epoch in range(10):
        for data, labels in ds:
            train_step(data, labels)
    
        template = 'Epoch {:>2}, Loss: {:>7.4f}'
        print(template.format(epoch + 1, float(loss.result())))
        loss.reset_state()
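
    Since the question asks for an element of the raw input vector rather than a hidden activation, a minimal variation of MyModel.call (a sketch, assuming the input has at least 10 columns) would return that column directly:

    def call(self, inputs, training=None, **kwargs):
        # 10th element of the raw input, shape (batch,),
        # used as the extra loss term instead of a hidden activation
        return self.d2(self.d1(self.d0(inputs))), inputs[:, 9]

    Either way, call returns a tuple, so at inference time take the first element, e.g. predictions, _ = model(batch_x).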