Tags: python, tensorflow, keras, tensorflow-probability

Reparametrization in tensorflow-probability: tf.GradientTape() doesn't calculate the gradient with respect to a distribution's mean


In tensorflow version 2.0.0-beta1, I am trying to implement a keras layer whose weights are sampled from a normal distribution. I would like the mean of that distribution to be a trainable parameter.

Thanks to the "reparametrization trick", which is already implemented in tensorflow-probability, calculating the gradient with respect to the mean of the distribution should be possible in principle, if I am not mistaken.
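
To be explicit about what I mean by the reparametrization trick: the sample is written as a deterministic, differentiable function of the distribution's parameters plus independent noise, for example w = mean + stddev * eps with eps ~ N(0, 1). Here is a small standalone sketch in plain TensorFlow (just for illustration, without tensorflow-probability) of why a gradient with respect to the mean should exist:

import tensorflow as tf

mu = tf.Variable(0.5)   # trainable mean
sigma = 1.0             # fixed standard deviation

with tf.GradientTape() as tape:
    eps = tf.random.normal(shape=(3, 1))  # noise, independent of mu
    w = mu + sigma * eps                  # reparametrized sample ~ N(mu, sigma^2)
    y = tf.reduce_sum(w ** 2)             # dummy downstream computation

print(tape.gradient(y, mu))  # finite tensor, not None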

However, when I try to calculate the gradient of the network output with respect to the mean variable using tf.GradientTape(), the returned gradient is None.

I created two minimal examples, one of a layer with deterministic weights and one of a layer with random weights. The gradients of the deterministic layer are calculated as expected, but the gradients of the random layer are None. There is no error message giving details on why the gradient is None, and I am kind of stuck.

Minimal example code:

A: Here is the minimal example for the deterministic network:

import tensorflow as tf; print(tf.__version__)

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer,Input
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import RandomNormal
import tensorflow_probability as tfp

import numpy as np

# example data
x_data = np.random.rand(99,3).astype(np.float32)

# # A: DETERMINISTIC MODEL

# 1 Define Layer

class deterministic_test_layer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(deterministic_test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(deterministic_test_layer, self).build(input_shape)

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

# 2 Create model and calculate gradient

x = Input(shape=(3,))
fx = deterministic_test_layer(1)(x)
deterministic_test_model = Model(name='test_deterministic',inputs=[x], outputs=[fx])

print('\n\n\nCalculating gradients for deterministic model: ')

for x_now in np.split(x_data,3):
#     print(x_now.shape)
    with tf.GradientTape() as tape:
        fx_now = deterministic_test_model(x_now)
        grads = tape.gradient(
            fx_now,
            deterministic_test_model.trainable_variables,
        )
        print('\n',grads,'\n')

print(deterministic_test_model.summary())

B: The following example is very similar, but instead of deterministic weights I tried to use randomly sampled weights (sampled at call() time!) for the test layer:

# # B: RANDOM MODEL

# 1 Define Layer

class random_test_layer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(random_test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mean_W = self.add_weight('mean_W',
                                      initializer=RandomNormal(mean=0.5,stddev=0.1),
                                      trainable=True)

        self.kernel_dist = tfp.distributions.MultivariateNormalDiag(loc=self.mean_W,scale_diag=(1.,))
        super(random_test_layer, self).build(input_shape)

    def call(self, x):
        sampled_kernel = self.kernel_dist.sample(sample_shape=x.shape[1])
        return K.dot(x, sampled_kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

# 2 Create model and calculate gradient

x = Input(shape=(3,))
fx = random_test_layer(1)(x)
random_test_model = Model(name='test_random',inputs=[x], outputs=[fx])

print('\n\n\nCalculating gradients for random model: ')

for x_now in np.split(x_data,3):
#     print(x_now.shape)
    with tf.GradientTape() as tape:
        fx_now = random_test_model(x_now)
        grads = tape.gradient(
            fx_now,
            random_test_model.trainable_variables,
        )
        print('\n',grads,'\n')

print(random_test_model.summary())

Expected/Actual Output:

A: The deterministic network works as expected, and the gradients are calculated. The output is:

2.0.0-beta1



Calculating gradients for deterministic model: 

 [<tf.Tensor: id=26, shape=(3, 1), dtype=float32, numpy=
array([[17.79845  ],
       [15.764006 ],
       [14.4183035]], dtype=float32)>] 


 [<tf.Tensor: id=34, shape=(3, 1), dtype=float32, numpy=
array([[16.22232 ],
       [17.09122 ],
       [16.195663]], dtype=float32)>] 


 [<tf.Tensor: id=42, shape=(3, 1), dtype=float32, numpy=
array([[16.382954],
       [16.074356],
       [17.718027]], dtype=float32)>] 

Model: "test_deterministic"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
deterministic_test_layer (de (None, 1)                 3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
None

B: However, in the case of the similar random network, the gradients are not calculated as expected (using the reparametrization trick). Instead, they are None. The full output is:

Calculating gradients for random model: 

 [None] 


 [None] 


 [None] 

Model: "test_random"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
random_test_layer (random_te (None, 1)                 1         
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
None

Can anybody point me at the problem here?


Solution

  • It seems that tfp.distributions.MultivariateNormalDiag is not differentiable with respect to its input parameters (e.g. loc). In this particular case, the following would be equivalent:

    class random_test_layer(Layer):
        ...
    
        def build(self, input_shape):
            ...
        self.kernel_dist = tfp.distributions.MultivariateNormalDiag(loc=0., scale_diag=(1.,))
            super(random_test_layer, self).build(input_shape)
    
        def call(self, x):
            sampled_kernel = self.kernel_dist.sample(sample_shape=x.shape[1]) + self.mean_W
            return K.dot(x, sampled_kernel)
    

    In this case, however, the loss is differentiable with respect to self.mean_W.

    Be careful: Although this approach might work for your purposes, note that calling the density function self.kernel_dist.prob would yield different results, since we moved loc outside the distribution.
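
    As a quick standalone check of this idea (a minimal sketch, independent of the layer code above): sampling from a distribution with a constant loc and adding the trainable mean afterwards gives a finite gradient with respect to the mean instead of None:

    import tensorflow as tf
    from tensorflow.keras import backend as K
    import tensorflow_probability as tfp

    mean_W = tf.Variable(0.5)
    # loc is a constant here; the trainable mean is added after sampling
    kernel_dist = tfp.distributions.MultivariateNormalDiag(loc=(0.,), scale_diag=(1.,))

    x = tf.random.uniform((4, 3))
    with tf.GradientTape() as tape:
        sampled_kernel = kernel_dist.sample(sample_shape=x.shape[1]) + mean_W
        fx = K.dot(x, sampled_kernel)

    print(tape.gradient(fx, mean_W))  # finite tensor, not None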