Tags: python, tensorflow, deep-learning, gradient-descent

tf.gradients to tf.GradientTape


I have the following code in one part of my program:

inverse = tf.gradients(x_conv, x, x_conv)[0]
reconstruction_loss = tf.nn.l2_loss(inverse - tf.stop_gradient(x))

where x_conv is a float32 Tensor of shape (384, 24, 1051) and x is a float32 Tensor of shape (4, 3, 32, 4201). I want to move away from tf.gradients because it requires disabling eager execution, which breaks many of my other operations.

TensorFlow suggests using tf.GradientTape() instead, but I could not find an example that also passes x_conv as the initial gradient, which, as I understand it, is what the original code does.

I tried the following, using random data for reproducibility, but inverse comes back as None. I am also unsure how to rewrite the part that uses tf.stop_gradient.

import numpy as np
import tensorflow as tf

data = tf.random.uniform((4,3,16800), dtype=tf.float32)

with tf.GradientTape() as tape:
  x = data
  shape_input = x.get_shape().as_list()
  shape_fast = [np.prod(shape_input[:-1]), 1, shape_input[-1]]
  kernel_size = 1794
  paddings = [0, 0], [0, 0], [kernel_size // 2 - 1, kernel_size // 2 + 1]
  filters_kernel = tf.random.uniform((1794, 1, 16), dtype=tf.float32)
  x_reshape = tf.reshape(x, shape_fast)
  x_pad = tf.pad(x_reshape, paddings=paddings, mode='SYMMETRIC')
  x_conv = tf.nn.conv1d(x_pad, filters_kernel, stride=2,
                              padding='VALID', data_format='NCW')
inverse = tape.gradient(x_conv, x, output_gradients=x_conv)

Does anyone know how I could rewrite this part, or whether there are other functions I could use? I am using TensorFlow 2.11.0.

For reference, the full code is at https://github.com/leonard-seydoux/scatnet/blob/master/scatnet/layer.py and the section this problem relates to is lines 218 to 220.


Solution

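  • Add tape.watch(data) inside the tape context. A GradientTape automatically tracks only trainable tf.Variable objects; data is a plain tensor, so the tape ignores it unless it is watched explicitly, and tape.gradient then returns None: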

    ...
    with tf.GradientTape() as tape:
      tape.watch(data)  
      x = data
      ...
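
For completeness, here is a minimal end-to-end sketch of the fix. It is not the code from the linked repository: the shapes and the kernel are small stand-ins chosen so it runs quickly, the padding step is left out, and the convolution uses the default NWC layout instead of NCW so that it also runs on CPU. The output_gradients argument of tape.gradient plays the role of the third argument (grad_ys) of tf.gradients.

```python
import tensorflow as tf

# Stand-in data and filters (assumption: much smaller than in the question).
data = tf.random.uniform((4, 3, 64), dtype=tf.float32)
kernel = tf.random.uniform((8, 1, 16), dtype=tf.float32)

with tf.GradientTape() as tape:
    # Plain tensors are not tracked automatically, so watch the input.
    tape.watch(data)
    # NWC layout: (batch, width, channels).
    x = tf.reshape(data, (12, 64, 1))
    x_conv = tf.nn.conv1d(x, kernel, stride=2, padding='VALID')

# x_conv is used as the upstream (initial) gradient, as in
# tf.gradients(x_conv, x, x_conv)[0].
inverse = tape.gradient(x_conv, data, output_gradients=x_conv)

# tf.stop_gradient also works in eager mode, so this line is unchanged.
reconstruction_loss = tf.nn.l2_loss(inverse - tf.stop_gradient(data))

# For comparison: without tape.watch the gradient is None.
with tf.GradientTape() as unwatched_tape:
    y = tf.reduce_sum(data * 2.0)
unwatched_gradient = unwatched_tape.gradient(y, data)
```

Note that tf.stop_gradient only has an effect if reconstruction_loss is itself differentiated later. If you need that (for example with respect to trainable filters), the whole computation above, including the tape.gradient call, would most likely have to sit inside an outer tf.GradientTape.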