I want to use TensorFlow Probability to train a simple fully connected Bayesian neural network. The loss is composed of KL terms and a negative log-likelihood term. How can I monitor the evolution of each term separately with TFP?
I have the following code:
```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions

# [make some data for a regression task]
# (defines x_train, y_train, x_val, y_val and n_features)

inputs = tfkl.Input(shape=(n_features,))
x = inputs
# Two Bayesian dense layers; each contributes a KL term to model.losses.
x = tfpl.DenseFlipout(100, activation='relu')(x)
x = tfpl.DenseFlipout(2)(x)
# Output a Normal: first unit is the mean, second (through softplus) the scale.
x = tfpl.DistributionLambda(lambda t: tfd.Normal(loc=t[..., :1],
                                                 scale=1e-3 + tf.math.softplus(t[..., 1:])))(x)
model = tfk.Model(inputs, x)

# Negative log-likelihood of the targets under the predicted distribution.
negloglik = lambda y, rv_y: -rv_y.log_prob(y)

model.compile(optimizer=tf.optimizers.Adam(), loss=negloglik, metrics=['mse'])
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))
```
The loss function is the sum of the explicit term `negloglik` and a KL divergence term from each `DenseFlipout` layer (I can see those terms by inspecting `model.losses`, for example).
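For reference, here is a sketch of what Keras minimizes per training batch. It assumes the default behaviour of averaging the compiled loss over the batch and adding every entry of `model.losses` unscaled, and that each `DenseFlipout` layer uses its default `kernel_divergence_fn` (a KL divergence between the kernel posterior and its prior). $B$ is the batch size, $\tilde{W}^{(i)}$ the Flipout weight sample used for example $i$, and $q(W_l)$, $p(W_l)$ the posterior and prior of layer $l$:

```latex
\mathcal{L}_{\text{batch}}
  = \underbrace{\frac{1}{B}\sum_{i=1}^{B} -\log p\big(y_i \mid x_i, \tilde{W}^{(i)}\big)}_{\text{mean of \texttt{negloglik}}}
  \;+\; \underbrace{\sum_{l=1}^{2} \mathrm{KL}\big(q(W_l) \,\|\, p(W_l)\big)}_{\text{\texttt{tf.add\_n(model.losses)}}}
```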
How can I visualize each of these terms separately?
An attempt:
If I try to add a function that calculates `negloglik` to the metrics, such as
```python
def negloglik_met(y_true, y_pred):
    return -y_pred.log_prob(y_true)
```
I get `AttributeError: 'Tensor' object has no attribute 'log_prob'`, which confuses me: `y_pred` should be the output of the `DistributionLambda` layer, so why is it a `Tensor` and not a `Distribution`?
Something else I hoped would work, but does not, is adding `model.losses[0]` to the metrics. There I get `ValueError: Could not interpret metric function identifier: Tensor("dense_flipout/divergence_kernel:0", shape=(), dtype=float32)`.
I drilled down into the TensorFlow code. The cause is the wrapper that TensorFlow automatically creates around your (lambda) metric function: it casts and reshapes the model output (the distribution) to the dtype of the metric (which seems odd to me anyway). Casting the output of a `DistributionLambda` converts it into a plain `Tensor` (by default, a sample drawn from the distribution), which is why `log_prob` is missing. So, to prevent this, create your own wrapper that doesn't perform the cast. The code that does the cast is at: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py#L583
Use that block of code as a template for your own metric wrapper. Ideally, this would be a built-in feature of TFP.