
Tensorflow: Trainable Variable Masking


I am working on a convolutional neural net that requires some parts of the kernel weights to be untrainable. tf.nn.conv2d(x, W) takes a trainable variable W as its weights. How can I make some of the elements of W untrainable?


Solution

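  • Maybe you could keep your trainable weights in W1, a boolean mask M that is True where the weights should be trainable, and the fixed values in a constant (untrainable) tensor W2 of the same shape. Build the kernel from both and pass it to tf.nn.conv2d. Because W1 is multiplied by zero wherever M is False, those entries receive zero gradient, and W keeps the values of W2 there: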

    W = (tf.multiply(W1, tf.cast(M, dtype=W1.dtype))
         + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype)))
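A minimal end-to-end sketch of the masking idea, assuming TensorFlow 2.x with eager execution. The 3x3 single-channel kernel, freezing only the centre tap, the fixed value 0.5 in W2, and the helper names `masked_kernel` and `train_step` are illustrative assumptions, not part of the answer. It takes one plain gradient-descent step on W1 and checks that the frozen entry gets zero gradient and keeps its fixed value.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Assumed example shape: 3x3 kernel, 1 input channel, 1 output channel.
KERNEL_SHAPE = (3, 3, 1, 1)

# M is True where weights may be trained. Freezing only the centre tap is
# an illustrative assumption.
M = tf.reshape(
    tf.constant([[True, True, True],
                 [True, False, True],
                 [True, True, True]]),
    KERNEL_SHAPE)

W1 = tf.Variable(tf.random.normal(KERNEL_SHAPE))  # trainable weights
W2 = tf.fill(KERNEL_SHAPE, 0.5)                    # fixed weights (a plain tensor, never updated)


def masked_kernel():
    """Hypothetical helper: take W1 where M is True and W2 where M is False."""
    return (tf.multiply(W1, tf.cast(M, dtype=W1.dtype))
            + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype)))


def train_step(x, target, learning_rate=0.1):
    """Hypothetical helper: one plain gradient-descent step on W1; returns the gradient."""
    with tf.GradientTape() as tape:
        y = tf.nn.conv2d(x, masked_kernel(), strides=1, padding="SAME")
        loss = tf.reduce_mean(tf.square(y - target))
    grad = tape.gradient(loss, W1)
    W1.assign_sub(learning_rate * grad)
    return grad


x = tf.random.normal((1, 5, 5, 1))
target = tf.zeros((1, 5, 5, 1))
grad = train_step(x, target)

# The frozen centre tap gets zero gradient, and the effective kernel keeps W2's value there.
print(float(grad[1, 1, 0, 0]))
print(float(masked_kernel()[1, 1, 0, 0]))  # 0.5
```

Only W1 is a tf.Variable, so only W1 is ever updated. The masked-out entries of W1 still exist but have no effect on the output. An optimizer with weight decay could still change them, which is harmless because they are multiplied by zero.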