Here's the code:
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
nvp = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.]),
    bijector=tfb.RealNVP(
        num_masked=2,
        shift_and_log_scale_fn=tfb.real_nvp_default_template(
            hidden_layers=[512, 512])))
x = nvp.sample((32, 32))
x = nvp.sample((32, 32)) gives me a tensor of shape 32x32x3. But when I pass x to nvp.log_prob(x), I get a tensor of shape 32x32. I was expecting a tensor of shape (1,), since I want the log_prob of the whole 32x32x3 tensor.
So the question is: how do I change the code above to compute the log_prob of a tensor of shape 32x32x3?
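RealNVP transforms vector-valued distributions (MultivariateNormalDiag in your case above), so the event shape here is [3] and the leading 32x32 is a batch of independent samples. You can try nvp.distribution.log_prob(x) (that is, apply the underlying distribution's log_prob) and note that the result has the same 32x32 shape. The log_prob function "consumes" the event shape of x and returns one value per sample.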
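To make the shape bookkeeping concrete, here is a minimal NumPy sketch rather than TFP code. The helper mvn_diag_log_prob is a hypothetical stand-in for a diagonal Gaussian's log_prob, and summing over the 32x32 axes assumes you want to treat those draws as independent samples whose joint log-probability is a single scalar.

```python
import numpy as np

def mvn_diag_log_prob(x, loc, scale_diag):
    """Hypothetical stand-in for a diagonal Gaussian log_prob.

    Reduces over the last axis (the event axis), so an input of shape
    (..., d) yields an output of shape (...).
    """
    z = (x - loc) / scale_diag
    per_dim = -0.5 * z**2 - np.log(scale_diag) - 0.5 * np.log(2.0 * np.pi)
    return per_dim.sum(axis=-1)

rng = np.random.default_rng(0)
loc = np.zeros(3)
scale_diag = np.ones(3)

# 32x32 independent draws, each a 3-vector: sample shape (32, 32), event shape (3,)
x = rng.normal(size=(32, 32, 3))

# One log-density per draw: the event axis is consumed.
lp = mvn_diag_log_prob(x, loc, scale_diag)

# Joint log-density of all draws, assuming they are independent.
joint = lp.sum()

print(x.shape, lp.shape, np.ndim(joint))
```

If the same assumption fits your use case, the analogous step on the TFP side would be to reduce the 32x32 result yourself, for example with tf.reduce_sum(nvp.log_prob(x)).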
The log_prob of a transformed distribution is
nvp.distribution.log_prob(nvp.bijector.inverse(x)) + nvp.bijector.inverse_log_det_jacobian(x, event_ndims=1)
(The sign is a plus because this is the log-det-Jacobian of the inverse; written with the forward log-det-Jacobian it would be a minus.)
Namely, it is the sum of the underlying distribution's log_prob applied to the samples pulled back through the bijective transformation and a correction term that accounts for the (local, at x) change in volume induced by the bijective transformation.
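As a sanity check on the sign, here is a small plain-Python sketch with a one-dimensional affine map standing in for the bijector. The names a, b, inverse and inverse_log_det_jacobian are illustrative only and are not the TFP API. Pulling a standard normal back through y = a * x + b and adding the inverse log-det-Jacobian should reproduce the closed-form density of a normal with mean b and standard deviation |a|.

```python
import math

def std_normal_log_prob(x):
    """Log-density of the base distribution, a standard normal."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

# Hypothetical bijector: y = forward(x) = a * x + b
a, b = 2.0, 1.0

def inverse(y):
    """Pull a sample y back to the base space."""
    return (y - b) / a

def inverse_log_det_jacobian(y):
    """log |d inverse / dy| = log |1 / a| = -log |a| (constant for an affine map)."""
    return -math.log(abs(a))

def transformed_log_prob(y):
    """Base log_prob of the pulled-back sample PLUS the inverse log-det-Jacobian."""
    return std_normal_log_prob(inverse(y)) + inverse_log_det_jacobian(y)

def normal_log_prob(y, mean, std):
    """Closed-form log-density of Normal(mean, std), used as the reference."""
    z = (y - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

y = 0.3
# The two printed values should agree up to floating-point error.
print(transformed_log_prob(y), normal_log_prob(y, b, abs(a)))
```

Flipping the plus to a minus in transformed_log_prob breaks the agreement, which is a quick way to check the sign convention for any bijector.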