tensorflow, tensorflow-probability

Having trouble understanding how the log_prob of TensorFlow Probability's RealNVP bijector works


Here's the code:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
nvp = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.]),
    bijector=tfb.RealNVP(
        num_masked=2,
        shift_and_log_scale_fn=tfb.real_nvp_default_template(
            hidden_layers=[512, 512])))

x = nvp.sample((32,32))

x = nvp.sample((32, 32)) gives me a tensor of shape 32x32x3. But when I pass x to nvp.log_prob(x), I get a tensor of shape 32x32. I was expecting a scalar-like tensor of shape (1,), since I want the log_prob of the whole 32x32x3 tensor.

So the question is: how do I change the code above to compute the log_prob of a 32x32x3 tensor?


Solution

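  • RealNVP transforms vector-valued distributions (MultivariateNormalDiag in your case above). Try nvp.distribution.log_prob(x), which applies the underlying distribution's log_prob, and note that the result has the same 32x32 shape. The log_prob function "consumes" the event shape of x: the trailing dimension of size 3 is one event, so each 3-vector yields one scalar, and the leading 32x32 is the sample shape. If you want a single number for all 32x32 independent draws, sum those values, e.g. tf.reduce_sum(nvp.log_prob(x)).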

    The log_prob of a transformed distribution is something like

    nvp.distribution.log_prob(nvp.bijector.inverse(x)) + nvp.bijector.inverse_log_det_jacobian(x, event_ndims=1)

    Namely, it is the sum of the underlying distribution's log_prob applied to the samples pulled back through the bijective transformation plus a correction term to account for the (local, at x) change in volume induced by the bijective transformation.
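To make the formula and the shape behaviour concrete, here is a minimal sketch in plain Python that does not use TensorFlow Probability. It assumes a made-up elementwise affine bijector in place of RealNVP (the SHIFT and LOG_SCALE values and all function names are hypothetical), and only illustrates that log_prob is the base log density of the pulled-back point plus the inverse log-det-Jacobian, giving one value per 3-vector.

```python
import math

# Hypothetical elementwise affine bijector x = SHIFT + exp(LOG_SCALE) * z,
# standing in for RealNVP; these values are made up for illustration.
SHIFT = [1.0, -2.0, 0.5]
LOG_SCALE = [0.3, -0.1, 0.0]

def base_log_prob(z):
    """Standard normal log density of one 3-vector; sums over the event dimension."""
    return sum(-0.5 * v * v - 0.5 * math.log(2.0 * math.pi) for v in z)

def inverse(x):
    """Pull a sample x back to the base space."""
    return [(xi - s) * math.exp(-ls) for xi, s, ls in zip(x, SHIFT, LOG_SCALE)]

def inverse_log_det_jacobian(x):
    """log|det dz/dx|; for this affine map it does not depend on x."""
    return -sum(LOG_SCALE)

def log_prob(x):
    """Change of variables: base log_prob of the pulled-back point PLUS the inverse log-det."""
    return base_log_prob(inverse(x)) + inverse_log_det_jacobian(x)

def direct_log_prob(x):
    """Reference: independent normals with mean SHIFT and std exp(LOG_SCALE)."""
    total = 0.0
    for xi, s, ls in zip(x, SHIFT, LOG_SCALE):
        z = (xi - s) / math.exp(ls)
        total += -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - ls
    return total

# A 2x2 grid of 3-vectors, analogous to a (32, 32, 3) sample tensor.
xs = [[[0.0, 0.0, 0.0], [1.0, -2.0, 0.5]],
      [[2.0, 1.0, -1.0], [0.5, 0.5, 0.5]]]

per_sample = [[log_prob(x) for x in row] for row in xs]  # 2x2: one value per 3-vector
joint = sum(sum(row) for row in per_sample)              # one scalar for the whole grid
```

The 2x2 result mirrors the 32x32 output in the question: the last axis is consumed as the event, and summing over the remaining axes gives a single log-probability under the assumption that the draws are independent.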