Tags: python, tensorflow2.0, mixture-model, tensorflow-probability

Mixture of multivariate Gaussian distributions in TensorFlow Probability


As the title says, I am trying to create a mixture of multivariate normal distributions using the TensorFlow Probability package.

In my original project, I am feeding the weights of the categorical, the loc, and the variance from the output of a neural network. However, when creating the graph, I get the following error:

components[0] batch shape must be compatible with cat shape and other component batch shapes

I recreated the same problem using placeholders:

import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions 

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')

log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                    initializer=tf.constant_initializer(1.0),
                                    trainable=True)

mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,1], name='weights')

cat = tfp.distributions.Categorical(probs=[mix, 1.-mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]

bimix_gauss = tfp.distributions.Mixture(
  cat=cat,
  components=components)

So my question is: what am I doing wrong? I looked into the error, and it seems tensorshape_util.is_compatible_with is what raises it, but I don't see why.

Thanks!


Solution

  • It seems you passed a mis-shaped input to tfp.distributions.Categorical. Its probs parameter should have shape [batch_size, cat_size], whereas the one you provide has shape [cat_size, batch_size, 1]. So try parametrizing probs with tf.concat([mix, 1-mix], 1) instead.

    There may also be a problem with your log_std, which doesn't have the same shape as l1 and l2. In case MultivariateNormalDiag doesn't broadcast it properly, try specifying its shape as (None, 2), or tile it so that its first dimension matches that of your location parameters.