
Passing tensorflow-probability distributions as bijector parameters


I want to create a TransformedDistribution whose transforming bijector (a Chain of bijectors) has some of its components parametrized by distributions, so that I get a different result every time I transform a tensor through the bijector (because the bijector's parameters are sampled each time as well).

Let's take a very simple example to illustrate what I mean (without Chains):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

base_distribution = tfd.Normal(loc=0, scale=1.5)
x = base_distribution.sample(10) # Sample 10 times, to get fixed values
# Define the distribution for the parameter of the bijector
scale_distribution = tfd.Uniform(low=2/3-0.5, high=2/3+0.5)
# This is how I wish it was, but fails
bijector = tfb.Scale(scale=scale_distribution)
# ValueError: TypeError: object of type 'Uniform' has no len()

Elaborating on the example a bit more:

transformed_dist = tfd.TransformedDistribution(
    base_distribution,
    bijector)

transformed_dist.sample() # get one sample of the original distribution, 
                          # scaled with some scale factor drawn from the
                          # uniform distribution.

I know there is a way to build hierarchical distribution models with a JointDistribution, which lets me use one (or more) distributions as the parameters of another distribution and then sample from the joint distribution.

But I have not found an equivalent way to randomly parametrize a bijector. An approach that "works" but is a bit cumbersome is:

# Works, but the sample is generated and becomes fixed when defining the bijector
bijector = tfb.Scale(scale=tf.random.uniform(shape=[], minval=2/3-0.5, maxval=2/3+0.5))

transformed_dist = tfd.TransformedDistribution(
    base_distribution,
    bijector)

transformed_dist.sample() # Now the distribution will always sample from
                          # a Normal that is scaled by a deterministic
                          # parameter, which was randomly generated at
                          # definition time.

As I explain in the code, this would require me to rerun the whole code block every time I want the parameter to be different.

The reason I want to do it like this is that I want to automatically generate samples in such a way that a certain rotation is randomized, i.e. get a different distribution every time I sample.

Note: I'm using tensorflow >2.1 with eager execution.


Solution

  • I still think a JointDistribution is the way to go here, doing something like

    joint = tfd.JointDistributionSequential([
       tfd.Uniform(...),
       lambda unif: tfb.MyBijector(param=unif)(tfd.Normal(0., 1.))
    ])
    

    to feed into the bijector param.

    However, you could also cobble something together with tfp.util.DeferredTensor. This is an object that can be used anywhere a Tensor can be, but whose value is taken from the execution of a given function. E.g.

    rand_dt = tfp.util.DeferredTensor(
        pretransformed_input=1.,  # value doesn't matter here
        transform_fn=lambda _: tf.random.uniform(shape=[], minval=1., maxval=2.)
    )
    
    td = tfb.Scale(scale=rand_dt)(tfd.Normal(0., 1.))
    
    for i in range(5):
      print(td.sample())
      # will print a sample with a different random scale on each call
    

    Here's a colab, with a slightly modified version of the above example, to illustrate what's happening: https://colab.research.google.com/drive/1akjX6a1W-RJoUjsy0hVOrrRAQiBIYjY-

    HTH!

    UPDATE:

    Upon further reflection, I should actually give some pretty strong caution around the second pattern. Unfortunately it's very hard to guarantee that a single Tensor-ified value of a DeferredTensor will be used consistently -- even in the context of, say, a single Distribution method invocation -- and so doing anything that has side effects can be dangerous. For example:

    rand_dt = tfp.util.DeferredTensor(
        pretransformed_input=1.,  # value doesn't matter here
        transform_fn=lambda _: tf.random.uniform(shape=[], minval=1., maxval=2.)
    )
    
    td = tfb.Scale(scale=rand_dt)(tfd.Normal(0., 1.))
    
    sample = td.sample()  # sample from TD with some random scale
    print(td.log_prob(sample))  # new scale; wrong log_prob!
    

    One workaround would be to use the stateless TF sampling API, either directly or by passing list-valued or Tensor-valued seeds to TFP (see TFP's documentation on stateless sampling semantics), and explicitly manage the seeds. This may make the above example much uglier and harder to work with, though (you'd probably need a Variable floating around that acts as the seed input and gets assigned each time you want a new sample).

    Your best bet is probably to use the JointDistribution approach!