Tags: tensorflow, tensorflow-probability

Unable to get log_prob from a TransformedDistribution in tensorflow


I was following a tutorial on TransformedDistribution. In earlier versions of TensorFlow Probability, TransformedDistribution accepted batch_shape and event_shape arguments, so we could set batch_shape to [2] and event_shape to [4]. The current version no longer accepts those arguments. How can I make the code below work without going back to an earlier version?

The error raised was:

ValueError: event_ndims must be at least 0. Saw: 1

Code:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Parameters
n = 10000
loc = 0
scale = 0.5

# Normal distribution (scalar: batch_shape [], event_shape [])
normal = tfd.Normal(loc=loc, scale=scale)

# Batch of two random 4x4 lower-triangular scale matrices
tril = tf.random.normal((2, 4, 4))
scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)

# Define the scale linear operator bijector
scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)

# Intended: a transformed distribution with batch_shape [2] and event_shape [4]
mvn = tfd.TransformedDistribution(distribution=normal, bijector=scale_lin_op)

xn = normal.sample((n, 2, 4))

mvn.log_prob(xn)

Solution

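  • I believe that from version 0.12 on you need to build the event shape into the base distribution with tfd.Sample. Vector-valued loc and scale give a batch shape of [2], and sample_shape=[4] turns four i.i.d. draws into a single event of size 4, which matches the 4x4 matrices in scale_lin_op: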

    mvn = tfd.TransformedDistribution(
        distribution=tfd.Sample(
            tfd.Normal(loc=[loc, loc], scale=[scale, scale]),  # --> batch shape [2]
            sample_shape=[4]),  # --> event shape [4]
        bijector=scale_lin_op)
    
    mvn.log_prob(xn)
    

    Output:

    <tf.Tensor: shape=(10000, 2), dtype=float32, numpy=
    array([[-4.5943561e+00, -6.5238861e+01],
           [-8.6548815e+00, -2.1378198e+05],
           [-3.1688419e+01, -3.3126004e+04],
           ...,
           [-1.8664089e+00, -3.8012810e+03],
           [-1.8821844e+01, -1.3414998e+04],
           [-2.3645339e+00, -2.4178730e+05]], dtype=float32)>