Tags: python, tensorflow, neural-network, bayesian, tensorflow-probability

Extract learned NN posterior weight distribution parameters from DenseVariational layer


I also posted this question in the TensorFlow Probability GitHub issues: https://github.com/tensorflow/probability/issues/892

I'm using TensorFlow 2.1.0 and tensorflow-probability 0.9.0 with Python 3.6.8. I'm working with a TensorFlow Probability Keras model that has a DenseVariational layer defined as follows (lifted from examples found online):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    """Mean-field Normal posterior over the kernel and bias weights."""
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.))  # chosen so that softplus(c) == 1
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(2 * n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t[..., :n], scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
            reinterpreted_batch_ndims=1)),
    ])


def prior_trainable(kernel_size, bias_size=0, dtype=None):
    """Trainable Normal prior with learnable means and unit scale."""
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t, scale=1),
            reinterpreted_batch_ndims=1)),
    ])


dense = tfp.layers.DenseVariational(units=units,
                                    make_posterior_fn=posterior_mean_field,
                                    make_prior_fn=prior_trainable,
                                    )(prev_layer)

If I train my model and then remove the layers following this layer, the remaining model will output random variables from the learned posterior weight distributions. Something like this:

from tensorflow.keras import Model
# The DenseVariational layer is the 3rd-to-last layer in this case
cropped_model = Model(inputs, model.layers[-3].output)
cropped_model.predict(test_data)

Most of the time this is fine (e.g. for training, sampling, etc.). However, is there a direct way to get the learned loc and scale posterior values returned for a given input (e.g. test_data) to this cropped_model, instead of a sample drawn from the distribution they define?


Solution

  • You may refer to the 'Train model and Inspect' section of this webpage.

    I will briefly summarize the solution described there. Assuming the DenseVariational layer is the first layer of your trained model, you can retrieve the trained prior and posterior distributions, and then their means and variances, as follows. Since the distributions held by a DenseVariational layer do not depend on the layer's input, the dummy input can be any array:

    dummy_input = np.array([[0]])  # any array works; the VariableLayer ignores its input
    model_prior = model.layers[0]._prior(dummy_input)
    model_posterior = model.layers[0]._posterior(dummy_input)
    print('Prior variance: ', model_prior.variance().numpy())
    print('Posterior mean: ', model_posterior.mean().numpy())
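    If you also want the learned loc and scale parameters themselves rather than moments, here is a minimal sketch, assuming the posterior_mean_field definition above (which returns a tfd.Independent wrapping a tfd.Normal). Note that _posterior is a private attribute and may change between tensorflow-probability versions:

    posterior_dist = model.layers[0]._posterior(dummy_input)
    # tfd.Independent exposes its base distribution via .distribution
    locs = posterior_dist.distribution.loc.numpy()      # learned means
    scales = posterior_dist.distribution.scale.numpy()  # learned standard deviations
    print('Posterior locs:  ', locs)
    print('Posterior scales:', scales)

    These are exactly the t[..., :n] and softplus-transformed t[..., n:] halves of the flat variable created by the VariableLayer in posterior_mean_field.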