Tags: python, tensorflow, keras, gradient

Gradient of one layer w.r.t another layer when there is an input layer (and no value for the input)


I have a network written with the TensorFlow Keras functional API. I'd like to use the gradient of one layer w.r.t. the previous layer as input to another layer. I tried GradientTape and tf.gradients, and neither worked; I get the following error:

ValueError: tf.function-decorated function tried to create variables on non-first call.

At this point there is no concrete input value; I only have an Input layer.

Is it possible to do this in TensorFlow?

My code:

def Geo_branch(self, geo_inp):
    Fully_Connected1 = layers.TimeDistributed(layers.Dense(128, activation='tanh'))(geo_inp)
    Fully_Connected2 = layers.TimeDistributed(layers.Dense(64, activation='tanh'))(Fully_Connected1)
    return Fully_Connected2

@tf.function
def geo_extension(self, geo_branch):
    Fully_Connected = layers.TimeDistributed(layers.Dense(100, activation='tanh'))(geo_branch)
    geo_ext = layers.LSTM(6,
                          activation="tanh",
                          recurrent_activation="sigmoid",
                          unroll=False,
                          use_bias=True,
                          name='Translation'
                          )(Fully_Connected)

    grads = tf.gradients(geo_ext, geo_branch)
    return geo_ext, grads

inp_geo = layers.Input(shape=(self.time_size, 6), name='geo_input')
Geo_branch = Geo_branch(inp_geo)
geo_ext, grads = geo_extension(Geo_branch)

Any solution is appreciated. It doesn't have to use GradientTape, if there is any other way to compute these gradients.


Solution

  • I would inherit from TensorFlow's Layer class and create a custom layer. The error comes from geo_extension being decorated with @tf.function while it builds new Keras layers, and therefore new variables, on every call; moving layer construction into __init__ of a custom Layer avoids that. It is also probably beneficial to put everything under one call so as to minimize the likelihood of disconnections in the graph. (A sketch of feeding the returned gradient into a further layer follows the example.)

    Example:

    import tensorflow as tf
    
    from typing import List
    from typing import Optional
    from typing import Tuple
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.layers import Input
    from tensorflow.keras.layers import Layer
    from tensorflow.keras.layers import LSTM
    from tensorflow.keras.layers import TimeDistributed
    
    
    class CustomGeoLayer(Layer):
      """``CustomGeoLayer``."""
      def __init__(self, num_units: List[int], name: Optional[str] = None):
        super().__init__(name=name)
        self.num_units = num_units
        self.dense_0 = TimeDistributed(Dense(num_units[0], activation="tanh"))
        self.dense_1 = TimeDistributed(Dense(num_units[1], activation="tanh"))
        self.dense_2 = TimeDistributed(Dense(num_units[2], activation="tanh"))
        self.rnn = LSTM(units=num_units[3], activation="tanh",
                        recurrent_activation="sigmoid",
                        unroll=False, use_bias=True,
                        name="Translation")
        
      @tf.function  # tf.gradients needs graph mode, which tf.function provides
      def call(self,
               input_tensor: tf.Tensor,
               training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        x = self.dense_0(input_tensor)
        x = self.dense_1(x)
        r = self.dense_2(x)                 # keep a handle to differentiate against
        x = self.rnn(r, training=training)
        return x, tf.gradients(x, r)[0]     # d(LSTM output)/d(last dense output)
    
    
    # create model
    x_in = Input(shape=(10, 6))
    x_out = CustomGeoLayer([128, 64, 100, 6])(x_in)
    model = Model(x_in, x_out)
    
    # fake input data
    arr = tf.random.normal((3, 10, 6))
    
    # forward pass
    out, g = model(arr)
    
    print(out.shape)
    # (3, 6)
    
    print(g.shape)
    # (3, 10, 100)
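
    If the goal from the question is to feed that gradient into a further layer, the gradient output behaves like any other tensor in the functional API. A minimal sketch under that assumption, reusing CustomGeoLayer from above; the TimeDistributed Dense head and its width of 32 are made up for illustration:

    import tensorflow as tf

    from tensorflow.keras import Model
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.layers import Input
    from tensorflow.keras.layers import TimeDistributed

    x_in = Input(shape=(10, 6))
    geo_out, geo_grads = CustomGeoLayer([128, 64, 100, 6])(x_in)

    # the gradient is an ordinary tensor here, so any layer can consume it;
    # this head is a placeholder for whatever downstream layer you need
    grad_features = TimeDistributed(Dense(32, activation="relu"))(geo_grads)

    model = Model(x_in, [geo_out, grad_features])

    out, feats = model(tf.random.normal((3, 10, 6)))
    print(out.shape)    # (3, 6)
    print(feats.shape)  # (3, 10, 32)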