Tags: python, python-3.x, tensorflow-probability

Specifying a DirichletMultinomial in tensorflow probability


This is probably quite basic, but I can't figure it out: I have a 100x5 matrix y that was generated from a Dirichlet-Multinomial, and I want to infer the parameter gamma using TensorFlow Probability. Below is the model I implemented (for simplicity, I'm assuming for now that gamma is the same for all 5 classes):

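import tensorflow as tf
from tensorflow_probability import edward2 as ed

def dirichlet_multinomial_model(S, p, N, tau):
    # Shared concentration prior (the same gamma for all 5 classes).
    gamma = ed.Gamma(2.0, 3.0, name='gamma')
    # Counts over 5 classes with total_count=500.
    y = ed.DirichletMultinomial(500, tf.ones(5) * gamma, name='y')
    return y

log_joint = ed.make_log_joint_fn(dirichlet_multinomial_model)

def target_log_prob_fn(gamma):
    # Condition on the observed 100x5 count matrix y.
    return log_joint(
        S=S, p=p, N=N, tau=tau,
        gamma=gamma,
        y=y)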

When I try to sample from this using HMC, I get the following error:

ValueError: Incompatible shape for initialization argument 'value'. Expected (5,), got (100, 5).

So specifying a length-5 vector of gammas seems to have made the program expect my data to have shape (5,). I cannot work out how to specify the model correctly; any pointers would be appreciated.


Solution

  • As hinted in my comment, the fix here is to use sample_shape=[100,] instead of sample_shape=[100, 5]. We have 3 notions of shape in the TF Distributions library (which Edward wraps): sample shape, batch shape, and event shape.

    The event shape describes the shape of a single draw from the distribution. For example, a multivariate normal distribution in 5 dimensions has event_shape=[5].

    The batch shape describes independent, non-identically distributed draws; a "batch" of distributions. E.g., a Normal(loc=[1., 2., 3.], scale=1.) has batch_shape=[3] because of the 3 values passed to the loc parameter.

    The sample shape describes IID draws from a batch of distributions. The resulting sampled Tensor has shape S + B + E (the three shapes concatenated), where S, B, and E are the sample, batch, and event shapes, respectively.

    In your example, the DirichletMultinomial has a concentration parameter with shape [5,]. This corresponds to the event shape of the distribution: every draw from it is a vector of 5 non-negative integer counts adding up to total_count. When you sample the distribution 100 times, you do indeed get a result with shape=[100, 5], but the 5 is already accounted for by the event shape. You are only drawing 100 samples, hence sample_shape=[100,].

    Much of the above text is lifted from this great notebook, which has much more detail on TF Distribution shapes.

    Hope this helps clarify things! Happy sampling! ^_^