
When training a variational Bayesian neural network in tfp, how can I visualize the evolution of the different terms in the loss separately?


I want to use tensorflow-probability to train a simple fully-connected Bayesian neural network. The loss is composed of KL divergence terms and a negative log-likelihood term. How can I track how each of them evolves during training with tfp?

I have the following code:

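```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions

# Placeholder regression data (hypothetical; the question elides this step):
# a noisy linear target with n_features inputs.
n_features = 3
rng = np.random.default_rng(0)
x_all = rng.normal(size=(1200, n_features)).astype(np.float32)
y_all = (x_all @ rng.normal(size=(n_features, 1))
         + 0.1 * rng.normal(size=(1200, 1))).astype(np.float32)
x_train, y_train = x_all[:1000], y_all[:1000]
x_val, y_val = x_all[1000:], y_all[1000:]

inputs = tfkl.Input(shape=(n_features,))  # avoid shadowing the builtin `input`
x = tfpl.DenseFlipout(100, activation='relu')(inputs)
x = tfpl.DenseFlipout(2)(x)
# Column 0 parameterizes the mean, column 1 (through softplus) the scale.
x = tfpl.DistributionLambda(lambda t: tfd.Normal(loc=t[..., :1],
                                                 scale=1e-3 + tf.math.softplus(t[..., 1:])))(x)

model = tfk.Model(inputs, x)

# Negative log-likelihood of the targets under the predicted distribution.
negloglik = lambda y, rv_y: -rv_y.log_prob(y)

model.compile(optimizer=tf.optimizers.Adam(), loss=negloglik, metrics=['mse'])
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))
```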

The total loss is the sum of the explicit negloglik term and one KL divergence term per DenseFlipout layer (I can see those terms in model.losses, for example).
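To make that decomposition concrete, here is a NumPy sketch (not TFP code; the helper names and toy numbers are hypothetical). It assumes the DenseFlipout defaults of a mean-field normal weight posterior and a standard-normal prior, for which each weight contributes KL(N(μ, σ²) ‖ N(0, 1)) = −log σ + (σ² + μ²)/2 − 1/2, and it assumes Keras reports the batch loss as the batch-mean NLL plus the sum of the layers' KL terms.

```python
import numpy as np


def gaussian_nll(y, loc, scale):
    """Per-example negative log-likelihood of y under Normal(loc, scale)."""
    return 0.5 * np.log(2.0 * np.pi * scale**2) + (y - loc) ** 2 / (2.0 * scale**2)


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all weights of a layer."""
    return np.sum(-np.log(sigma) + 0.5 * (sigma**2 + mu**2) - 0.5)


def total_loss(y, loc, scale, posteriors):
    """Return (mean NLL, per-layer KL list, total) for one batch.

    `posteriors` is a list of (mu, sigma) arrays, one pair per Bayesian layer;
    these are hypothetical stand-ins for the DenseFlipout kernel posteriors.
    """
    nll = np.mean(gaussian_nll(y, loc, scale))
    kls = [kl_to_standard_normal(mu, sigma) for mu, sigma in posteriors]
    return nll, kls, nll + sum(kls)


# Hypothetical toy values: a batch of 8 targets and two layers (3->100, 100->2).
rng = np.random.default_rng(0)
y = rng.normal(size=(8, 1))
loc, scale = np.zeros_like(y), np.ones_like(y)
posteriors = [(0.1 * rng.normal(size=(3, 100)), np.full((3, 100), 0.5)),
              (0.1 * rng.normal(size=(100, 2)), np.full((100, 2), 0.5))]
nll, kls, loss = total_loss(y, loc, scale, posteriors)
print(f"NLL={nll:.3f}  KL per layer={[round(float(k), 1) for k in kls]}  total={loss:.1f}")
```

Tracking `nll` and each entry of `kls` separately, rather than only `loss`, is exactly the split the question asks for.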

How can I visualize each of these terms separately?


An attempt:

If I try to add a function that calculates negloglik to the metrics, such as

```python
def negloglik_met(y_true, y_pred):
    return -y_pred.log_prob(y_true)
```

I get `AttributeError: 'Tensor' object has no attribute 'log_prob'`, which confuses me: y_pred should be the output of the DistributionLambda layer, so why is it a Tensor and not a Distribution?

I also hoped that adding model.losses[0] to the metrics would work, but it fails with `ValueError: Could not interpret metric function identifier: Tensor("dense_flipout/divergence_kernel:0", shape=(), dtype=float32)`.


Solution

  • I drilled down into the TensorFlow code. The cause is the wrapper that Keras automatically creates around your (lambda) metric function: it casts and reshapes the model output (the distribution) to the metric's dtype, which turns the distribution into a plain Tensor (for a DistributionLambda output, by default a sample from it). That cast seems odd to me anyway. To prevent it, create your own wrapper that does not perform this cast. The code that does the cast is at: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py#L583

    Use that block of code as inspiration for your own metric wrapper. This should arguably be a feature of TFP.