I have a model like the one below. I want to add a matrix of learnable weights at the end, initialized to the `matrix` that I pass to `create_model`.
Intuitively, the matrix I pass in should already be close to correct, but I suspect it can still be fine-tuned. So I want it initialized to the values I pass and then refined during training.
The code below runs, but as the `model.summary()` output shows, the matrix multiplication contributes no learnable weights. That makes me think the matrix is not being fine-tuned.
What am I doing wrong?
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfl
from tensorflow.keras import backend as K

def create_model(num_columns, matrix):
    inp_layer = tfl.Input((num_columns,))
    dense = tfl.Dense(512, activation='relu')(inp_layer)
    dense = tfl.Dense(256, activation='relu')(dense)
    dense = tfl.Dense(128, activation='relu')(dense)
    # The matrix I want to fine-tune, multiplied onto the last hidden layer
    va = tf.Variable(matrix, dtype=tf.float32)
    dense = K.dot(dense, va)
    model = tf.keras.Model(inputs=inp_layer, outputs=dense)
    model.compile(optimizer='adam', loss=['binary_crossentropy'])
    model.summary()
    return model

matrix = np.random.randint(0, 2, (128, 206))  # In reality, this is not random, but has meaningful values
num_columns = 750
model = create_model(num_columns, matrix)
```
You can simply use a `Dense` layer with no bias to do this multiplication. After the model is built, replace that layer's weight (its kernel, the last entry of `model.get_weights()`) with the matrix you provided:
```python
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

def create_model(num_columns, matrix):
    inp_layer = Input((num_columns,))
    x = Dense(512, activation='relu')(inp_layer)
    x = Dense(256, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    # Bias-free Dense layer: output = x @ kernel, with a trainable (128, 206) kernel
    dense = Dense(206, use_bias=False)(x)
    model = Model(inputs=inp_layer, outputs=dense)
    model.compile(optimizer='adam', loss=['binary_crossentropy'])
    # The last weight array is that layer's kernel; overwrite it with `matrix`
    model.set_weights(model.get_weights()[:-1] + [matrix])
    model.summary()
    return model

matrix = np.random.randint(0, 2, (128, 206))  # In reality, this is not random, but has meaningful values
num_columns = 750
model = create_model(num_columns, matrix)
```
To check that the last layer's kernel now holds your matrix:

```python
(model.get_weights()[-1] == matrix).all()  # True
```
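This way the matrix is the kernel of an ordinary `Dense` layer, so `model.summary()` lists it as trainable parameters (128 × 206 = 26,368) and the optimizer fine-tunes it during training.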
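To make the idea concrete without TensorFlow, here is a minimal NumPy sketch of what happens to that kernel. A bias-free dense layer is just `x @ kernel`, the same product as `K.dot(dense, va)`, and because the kernel is a trainable parameter it receives a gradient and moves away from its initial value. The random `x` and `y`, the mean-squared-error loss, and the single plain gradient-descent step are illustrative assumptions; the model above actually uses Adam with binary cross-entropy.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: `matrix` plays the (128, 206) matrix from the question,
# `x` a batch of 4 outputs of the 128-unit ReLU layer, `y` binary targets.
matrix = rng.integers(0, 2, size=(128, 206)).astype(np.float32)
x = rng.random((4, 128), dtype=np.float32)
y = rng.integers(0, 2, size=(4, 206)).astype(np.float32)

# Dense(206, use_bias=False) computes x @ kernel; start the kernel at `matrix`.
kernel = matrix.copy()

def mse(w):
    """Mean squared error of the bias-free layer (illustrative loss only)."""
    return np.mean((x @ w - y) ** 2)

# Gradient of the mean squared error with respect to the kernel:
# d(loss)/d(out) = 2 * (out - y) / y.size, then backpropagate through x @ kernel.
out = x @ kernel
grad_out = 2.0 * (out - y) / y.size
grad_kernel = x.T @ grad_out            # same shape as kernel: (128, 206)

# One plain gradient-descent step (the question's model uses Adam instead).
loss_before = mse(kernel)
kernel -= 0.1 * grad_kernel
loss_after = mse(kernel)

print(kernel.shape)                     # (128, 206)
print(np.allclose(kernel, matrix))      # False: the kernel was fine-tuned
print(loss_after < loss_before)         # True
```

MSE is used here only because its gradient fits on one line; with a different loss, only `grad_out` changes, while the kernel update through `x.T @ grad_out` stays the same.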