Tags: python, keras, matrix-multiplication, keras-layer

Dot pipeline data with constant matrix


Is it possible to multiply the batch, in the middle of the pipeline, by a constant (non-trainable) matrix? Something along the lines of:

import numpy
import tensorflow as tf

constant_non_trainable_matrix = numpy.array([...])  # shape (n, n)

inputs = tf.keras.Input(shape=(n,))
dense_1 = tf.keras.layers.Dense(n)(inputs)  # Dense takes an int number of units
# MultiplyWithMatrix is hypothetical: it is the layer I am asking how to build
transform = MultiplyWithMatrix(constant_non_trainable_matrix)(dense_1)
outputs = tf.keras.layers.Dense(n)(transform)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

Solution

  • You can use a Lambda layer and backend.dot() to achieve that:

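    from keras import layers
    from keras import backend as K
    
    # ...
    # Wrap the (n, n) NumPy array as a constant (non-trainable) backend tensor
    mat = K.constant(constant_non_trainable_matrix)
    # Multiply every sample of the batch by mat: (batch, n) x (n, n) -> (batch, n)
    transformed = layers.Lambda(lambda x: K.dot(x, mat))(dense_1)
    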
    

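    You need to construct the mat tensor with the backend functions as well (e.g. K.constant() or K.variable()) rather than using the raw NumPy array. K.constant() returns a tensor with dtype K.floatx() (float32 by default), matching the layer activations, whereas a NumPy array of floats defaults to float64.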
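To make the semantics of the Lambda concrete: for a batch of shape (batch, n) and a 2-D mat of shape (n, n), K.dot(x, mat) is an ordinary matrix product, so every sample (row) x_i is multiplied by mat from the right, giving x_i @ mat. The NumPy sketch below checks this with illustrative values (batch_size, n and the random matrices are assumptions, not from the question). It also shows that if the constant transformation is meant in the column-vector sense, mat @ x_i, the layer should multiply by the transpose of the matrix instead.

```python
import numpy as np

# Illustrative sizes and data only; not taken from the question.
batch_size, n = 3, 4
rng = np.random.default_rng(0)

x = rng.standard_normal((batch_size, n))  # a batch of row vectors, like dense_1's output
mat = rng.standard_normal((n, n))         # stand-in for constant_non_trainable_matrix

# What the Lambda computes for 2-D inputs: a plain matrix product x @ mat.
batched = x @ mat

# The same result one sample at a time: each row x_i becomes x_i @ mat.
per_sample = np.stack([x[i] @ mat for i in range(batch_size)])

# Column-vector convention (mat applied to each sample as mat @ x_i)
# corresponds to multiplying the batch by the transpose: x @ mat.T.
column_convention = np.stack([mat @ x[i] for i in range(batch_size)])

print(batched.shape)                              # (3, 4)
print(np.allclose(batched, per_sample))           # True
print(np.allclose(x @ mat.T, column_convention))  # True
```

In the Keras model this means using K.dot(x, K.transpose(mat)), or building mat from constant_non_trainable_matrix.T, when the matrix is written for column vectors.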