TensorFlow version: 2.17.1
TensorFlow Probability version: 0.24.0
The following example is taken from the documentation (https://www.tensorflow.org/probability/api_docs/python/tfp/layers/MixtureNormal?hl=en):
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers
# Load data -- graph of a [cardioid](https://en.wikipedia.org/wiki/Cardioid).
n = 2000
t = tfd.Uniform(low=-np.pi, high=np.pi).sample([n, 1])
r = 2 * (1 - tf.cos(t))
x = r * tf.sin(t) + tfd.Normal(loc=0., scale=0.1).sample([n, 1])
y = r * tf.cos(t) + tfd.Normal(loc=0., scale=0.1).sample([n, 1])
# Model the distribution of y given x with a Mixture Density Network.
event_shape = [1]
num_components = 5
params_size = tfpl.MixtureNormal.params_size(num_components, event_shape)
model = tfk.Sequential([
    tfkl.Dense(12, activation='relu'),
    tfkl.Dense(params_size, activation=None),
    tfpl.MixtureNormal(num_components, event_shape)
])
# Fit.
batch_size = 100
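# The docs use tf.train.AdamOptimizer, a TF1 API that no longer exists in TF 2.x,
# so the Keras optimizer is used here instead.
model.compile(optimizer=tfk.optimizers.Adam(learning_rate=0.02),
              loss=lambda y, rv_y: -rv_y.log_prob(y))
model.fit(x, y,
          batch_size=batch_size,
          epochs=20,
          steps_per_epoch=n // batch_size)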
Running it fails at the Sequential construction with:
ValueError: Only instances of `keras.Layer` can be added to a Sequential model. Received: <tensorflow_probability.python.layers.distribution_layer.MixtureNormal object at 0x7c9076269a50> (of type <class 'tensorflow_probability.python.layers.distribution_layer.MixtureNormal'>)
The TensorFlow Probability release notes explain the cause:
"NOTE: In TensorFlow 2.16+, tf.keras (and tf.initializers, tf.losses, and tf.optimizers) refers to Keras 3. TensorFlow Probability is not compatible with Keras 3 -- instead TFP is continuing to use Keras 2, which is now packaged as tf-keras and tf-keras-nightly and is imported as tf_keras. When using TensorFlow Probability with TensorFlow, you must explicitly install Keras 2 along with TensorFlow (or install tensorflow-probability[tf] or tfp-nightly[tf] to automatically install these dependencies.)"
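Following those instructions fixes the problem: install Keras 2 (pip install tf-keras, using the release that matches your TensorFlow version, or pip install tensorflow-probability[tf]), import it as tf_keras, and build the model with tf_keras instead of tf.keras.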