As the title says, I am trying to create a mixture of multivariate normal distributions using the TensorFlow Probability package.
In my original project, I feed the categorical weights, the loc, and the variance from the output of a neural network. However, when building the graph, I get the following error:
components[0] batch shape must be compatible with cat shape and other component batch shapes
I recreated the same problem using placeholders:
import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()
l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')
log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                    initializer=tf.constant_initializer(1.0),
                                    trainable=True)
mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 1], name='weights')

cat = tfp.distributions.Categorical(probs=[mix, 1. - mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]
bimix_gauss = tfp.distributions.Mixture(
    cat=cat,
    components=components)
So, my question is: what am I doing wrong? I looked into the error, and it seems tensorshape_util.is_compatible_with is what raises it, but I don't see why.
Thanks!
It seems you provided a mis-shaped input to tfp.distributions.Categorical. Its probs parameter should have shape [batch_size, cat_size], while the one you provide is rather [cat_size, batch_size, 1]: passing the Python list [mix, 1.-mix] stacks the two [None, 1] tensors along a new leading axis. So try parametrizing probs with tf.concat([mix, 1. - mix], 1) instead.
There may also be a problem with your log_std, which doesn't have the same shape as l1 and l2. In case MultivariateNormalDiag doesn't properly broadcast it, try specifying its shape as (None, 2), or tile it so that its first dimension matches that of your location parameters.