Tags: tensorflow, tensorflow-probability

Probability of batched mixture distribution in TensorFlow Probability


TFP distributions should be batch-capable out of the box. However, I am facing a problem with a batched mixture distribution. Here is a toy example (eager execution is enabled):

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Batch of two sets of mixture weights, one row per batch member
mix = np.array([[0.6, 0.4], [0.3, 0.7]])

bimix_gauss = tfd.Mixture(
  cat=tfd.Categorical(probs=mix),
  components=[
    tfd.Normal(loc=[-1.0, -2.0], scale=[0.1, 0.1]),
    tfd.Normal(loc=[+1.0, +2.0], scale=[0.5, 0.5]),
  ])

print(bimix_gauss.sample())
print(bimix_gauss.prob(0.0))

Basically, it is just a batched version of the default example: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Mixture

Sampling works fine, but computing the probability of this distribution returns an error: InvalidArgumentError: cannot compute Add as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:Add] name: Mixture/prob/add/

Any guesses what I am doing wrong?

PS: The same example with a batched Gaussian distribution works fine.
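
For reference, a plain batched Gaussian built from Python lists (one reading of the PS above) evaluates without any dtype trouble, since no NumPy mixture weights are involved:

import tensorflow_probability as tfp
tfd = tfp.distributions

# Batched (non-mixture) Gaussian: prob works without a dtype mismatch
batch_gauss = tfd.Normal(loc=[-1.0, +1.0], scale=[0.1, 0.5])
print(batch_gauss.sample())
print(batch_gauss.prob(0.0))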


Solution

  • The issue is that NumPy defaults to float64, while TFP follows the TF convention of defaulting to float32. Your Normal distributions, whose parameters are bare Python lists, are cast to float32 tf.Tensors in the Normal constructor, while the Categorical built from the float64 NumPy array keeps float64, ultimately leading to a type error. You can fix this either by forcing the NumPy array to be float32, or perhaps more simply by passing the mixture weights as plain lists instead of ndarrays, as in the sketch below.
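
A minimal sketch of both fixes, reusing the toy example from the question (the commented-out line shows the plain-list alternative):

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Fix 1: force the NumPy array to float32 so it matches TFP's float32 default
mix = np.array([[0.6, 0.4], [0.3, 0.7]], dtype=np.float32)

# Fix 2 (alternative): pass plain Python lists, which TFP converts to float32
# mix = [[0.6, 0.4], [0.3, 0.7]]

bimix_gauss = tfd.Mixture(
  cat=tfd.Categorical(probs=mix),
  components=[
    tfd.Normal(loc=[-1.0, -2.0], scale=[0.1, 0.1]),
    tfd.Normal(loc=[+1.0, +2.0], scale=[0.5, 0.5]),
  ])

print(bimix_gauss.sample())
print(bimix_gauss.prob(0.0))  # all dtypes now agree, so prob no longer raises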