Say I have a 2-layer neural network and I want to make the second layer non-trainable. So I initialize these variables:
import tensorflow as tf

w1 = tf.Variable(tf.random.truncated_normal([28*28, 256]))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 10]))
b2 = tf.Variable(tf.zeros([10]))
and train them.
for (x, y) in db:
    x = tf.reshape(x, [-1, 28*28])
    with tf.GradientTape() as tape:
        h1 = x @ w1 + tf.broadcast_to(b1, [x.shape[0], 256])
        h1 = tf.nn.relu(h1)
        h2 = h1 @ w2 + tf.broadcast_to(b2, [x.shape[0], 10])
        out = tf.nn.relu(h2)
        y_onehot = tf.one_hot(y, depth=10)
        loss = tf.square(y_onehot - out)
        loss = tf.reduce_mean(loss)
The problem is that, with GradientTape() enclosing all the variables, w2 and b2 are also trainable. How can I make them non-trainable?
Your assumption that all variables inside the GradientTape will be trainable is incorrect. Gradients are only computed for the variables that you pass to the gradient function as its second argument: tape.gradient(tensor, <your trainable variables here>).
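So, to keep w2 and b2 frozen, simply leave them out of the sources you pass to tape.gradient. Here is a minimal sketch of that idea; the learning rate, the manual assign_sub update, and the simplified forward pass are just illustrative, and db stands for the (image, label) batches from your question:

import tensorflow as tf

# Same setup as in the question, wrapped in tf.Variable so the tape watches them
w1 = tf.Variable(tf.random.truncated_normal([28*28, 256]))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 10]))
b2 = tf.Variable(tf.zeros([10]))

lr = 1e-3  # illustrative learning rate

for (x, y) in db:  # db: the (image, label) batches from the question
    x = tf.reshape(x, [-1, 28*28])
    with tf.GradientTape() as tape:
        h1 = tf.nn.relu(x @ w1 + b1)      # the bias broadcasts automatically
        out = tf.nn.relu(h1 @ w2 + b2)
        y_onehot = tf.one_hot(y, depth=10)
        loss = tf.reduce_mean(tf.square(y_onehot - out))

    # Only w1 and b1 are listed as sources, so only they receive gradients;
    # w2 and b2 are never updated and effectively stay non-trainable.
    grads = tape.gradient(loss, [w1, b1])
    w1.assign_sub(lr * grads[0])
    b1.assign_sub(lr * grads[1])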