
Tensorflow-probability transform event shape of JointDistribution


I would like to create a distribution for n categorical variables C_1, ..., C_n whose event shape is n. Using JointDistributionSequentialAutoBatched, the event shape is instead a list of n scalar shapes, [[], ..., []]. For example, for n=2:

import tensorflow_probability as tfp

tfd = tfp.distributions

probs = [
    [0.8, 0.2], # C_1 in {0,1}
    [0.3, 0.3, 0.4] # C_2 in {0,1,2}
    ]

D = tfd.JointDistributionSequentialAutoBatched([tfd.Categorical(probs=p) for p in probs])

>>> D
<tfp.distributions.JointDistributionSequentialAutoBatched 'JointDistributionSequentialAutoBatched' batch_shape=[] event_shape=[[], []] dtype=[int32, int32]>

How do I reshape it to get event shape [2]?


Solution

  • A few different approaches could work here:

    1. Create a batch of Categorical distributions and then use tfd.Independent to reinterpret the batch dimension as the event:
    vector_dist = tfd.Independent(
      tfd.Categorical(
        probs=[
          [0.8, 0.2, 0.0],  # C_1 in {0,1}
          [0.3, 0.3, 0.4]  # C_2 in {0,1,2}
        ]),
      reinterpreted_batch_ndims=1)
    

    Here I added an extra zero to pad out probs so that both distributions can be represented by a single Categorical object.
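The padding is harmless because the extra category has probability zero. A quick NumPy sketch (not using TFP) checks that the padded rows are still valid distributions and that C_1 can never take the dummy value:

```python
import numpy as np

# Padded probability table from the example above: C_1 gets a dummy
# third category with probability 0.0 so both rows share one shape.
probs = np.array([
    [0.8, 0.2, 0.0],  # C_1 in {0, 1}
    [0.3, 0.3, 0.4],  # C_2 in {0, 1, 2}
])

# Each row still sums to 1, so each row is a valid categorical distribution.
assert np.allclose(probs.sum(axis=1), 1.0)

# Sampling C_1 from the padded row never yields the dummy category 2.
rng = np.random.default_rng(0)
samples = rng.choice(3, size=10_000, p=probs[0])
assert (samples < 2).all()
```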

    2. Use the Blockwise distribution, which stuffs its component distributions into a single vector (as opposed to the JointDistribution classes, which return them as separate values):
    vector_dist = tfd.Blockwise([tfd.Categorical(probs=p) for p in probs])
    
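Conceptually, a Blockwise sample packs one draw from each component into a single vector, giving the desired event shape [2]. A rough NumPy sketch of that behavior (an illustration, not TFP's actual implementation):

```python
import numpy as np

probs = [
    [0.8, 0.2],       # C_1 in {0, 1}
    [0.3, 0.3, 0.4],  # C_2 in {0, 1, 2}
]

rng = np.random.default_rng(0)

def blockwise_sample(probs, rng):
    # Draw one value from each categorical and pack the draws into a
    # single vector, mimicking Blockwise's event shape [2].
    return np.array([rng.choice(len(p), p=p) for p in probs])

sample = blockwise_sample(probs, rng)
assert sample.shape == (2,)  # one vector, not a list of scalars
assert 0 <= sample[0] <= 1 and 0 <= sample[1] <= 2
```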
    3. The closest to a direct answer to your question is to apply the Split bijector, whose inverse is Concat, to the joint distribution:
    import tensorflow_probability as tfp
    tfb = tfp.bijectors
    D = tfd.JointDistributionSequentialAutoBatched(
      [tfd.Categorical(probs=[p]) for p in probs])
    vector_dist = tfb.Invert(tfb.Split(2))(D)
    
    

    Note that I had to awkwardly write probs=[p] instead of just probs=p. This is because the Concat bijector, like tf.concat, can't change the tensor rank of its argument: it can concatenate small vectors into a big vector, but not scalars into a vector, so we have to ensure that its inputs are themselves vectors. This could be avoided if TFP had a Stack bijector analogous to tf.stack / tf.unstack (it doesn't currently, but there's no reason this couldn't exist).
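The rank restriction is easy to see with NumPy's analogous ops (tf.concat and tf.stack behave the same way):

```python
import numpy as np

# Concatenation cannot change rank: concatenating scalars (rank 0)
# is an error, which is why probs=[p] was needed to make each
# component a length-1 vector.
try:
    np.concatenate([np.array(0), np.array(1)])
    raised = False
except ValueError:
    raised = True
assert raised

# Length-1 vectors concatenate fine into one rank-1 vector...
assert np.concatenate([[0], [1]]).shape == (2,)

# ...while a Stack-style op adds the new axis itself, even for scalars.
assert np.stack([np.array(0), np.array(1)]).shape == (2,)
```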