How to use GradientTape() with non-trainable variables

Say I have a 2-layer neural network and I want to make the second layer non-trainable. I initialize these variables:

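import tensorflow as tf

# layer 1: 784 -> 256
w1 = tf.Variable(tf.random.truncated_normal([28*28, 256]))
b1 = tf.Variable(tf.zeros([256]))
# layer 2: 256 -> 10 (output width must match b2 and the 10 classes)
w2 = tf.Variable(tf.random.truncated_normal([256, 10]))
b2 = tf.Variable(tf.zeros([10]))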

Then I train them like this:

# db yields (images, labels) batches; images are assumed to be float32 already
for (x, y) in db:
    x = tf.reshape(x, [-1, 28*28])
    with tf.GradientTape() as tape:
        # layer 1: [b, 784] -> [b, 256]; b1 broadcasts over the batch
        h1 = tf.nn.relu(x @ w1 + b1)
        # layer 2: [b, 256] -> [b, 10]
        out = tf.nn.relu(h1 @ w2 + b2)
        y_onehot = tf.one_hot(y, depth=10)
        # mean squared error between one-hot labels and outputs
        loss = tf.reduce_mean(tf.square(y_onehot - out))

The problem is that, since the GradientTape() block encloses all of the variables, w2 and b2 get trained too. How can I make them non-trainable?
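A related pitfall: GradientTape automatically watches only trainable tf.Variable objects that are accessed inside its context. A plain tensor (such as the raw result of tf.random.truncated_normal) and a variable created with trainable=False are both ignored unless you call tape.watch on them, so tape.gradient returns None for them. A minimal sketch with small made-up shapes (the names w_tensor, w_var and w_frozen are purely illustrative):

```python
import tensorflow as tf

# Plain tensor: not watched by GradientTape automatically.
w_tensor = tf.random.truncated_normal([3, 2])
# Trainable tf.Variable: watched automatically once it is used inside the tape.
w_var = tf.Variable(tf.random.truncated_normal([3, 2]))
# Non-trainable tf.Variable: also not watched automatically.
w_frozen = tf.Variable(tf.random.truncated_normal([3, 2]), trainable=False)

x = tf.ones([1, 3])
with tf.GradientTape() as tape:
    loss = (tf.reduce_sum(x @ w_tensor)
            + tf.reduce_sum(x @ w_var)
            + tf.reduce_sum(x @ w_frozen))

grad_tensor, grad_var, grad_frozen = tape.gradient(loss, [w_tensor, w_var, w_frozen])
print(grad_tensor)  # None: the tensor was never watched
print(grad_frozen)  # None: trainable=False variables are not watched either
print(grad_var)     # a [3, 2] tensor of ones (d loss / d w_ij = x_i = 1)
```

So the parameters need to be tf.Variable objects for the tape to produce gradients for them at all; trainable=False is one way to exclude a variable at definition time.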


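  • Your assumption that every variable used inside the GradientTape context gets trained is incorrect. The tape only records the operations; nothing is updated until you compute gradients and apply them yourself.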

    Gradients are only computed for the variables you pass to tape.gradient as its second argument (the sources):

    tape.gradient(loss, [w1, b1])

    So to freeze the second layer, simply leave w2 and b2 out of that list (and out of the update step).
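A minimal sketch of this approach, assuming a random stand-in batch instead of the real db dataset and a hand-written SGD step with a hypothetical learning rate lr. All four parameters are ordinary trainable tf.Variables, but only [w1, b1] is passed to tape.gradient, so only the first layer is updated:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
lr = 0.1  # hypothetical learning rate

# Hypothetical stand-in for one MNIST batch from `db`.
x = tf.random.uniform([32, 28 * 28])
y = tf.random.uniform([32], maxval=10, dtype=tf.int32)

# All four parameters are ordinary (trainable) tf.Variables.
w1 = tf.Variable(tf.random.truncated_normal([28 * 28, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))

# Snapshots used afterwards to check which parameters changed.
w1_before = w1.numpy().copy()
w2_before = w2.numpy().copy()
b2_before = b2.numpy().copy()

with tf.GradientTape() as tape:
    h1 = tf.nn.relu(x @ w1 + b1)
    out = tf.nn.relu(h1 @ w2 + b2)
    loss = tf.reduce_mean(tf.square(tf.one_hot(y, depth=10) - out))

# Only the first layer is passed as sources, so only it receives gradients.
# w2 and b2 still take part in backprop (the gradient for w1 flows through
# them), but they are treated as constants and never updated.
first_layer = [w1, b1]
grads = tape.gradient(loss, first_layer)
for var, grad in zip(first_layer, grads):
    var.assign_sub(lr * grad)  # plain SGD step

print(not np.array_equal(w1.numpy(), w1_before))  # has the first layer changed?
print(np.array_equal(w2.numpy(), w2_before))      # True: second layer untouched
```

With a tf.keras optimizer the idea is the same: pass only the variables you want to train to both tape.gradient and apply_gradients.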