TensorFlow Probability: InvalidArgumentError: required broadcastable shapes


My data is in channels_first format. When I use TensorFlow Probability layers, I get the error shown below.

Here is an example where the input shape is [1,28,28]; the reproducible code is in the Gist (please make sure you run it on a GPU):

InvalidArgumentError:  required broadcastable shapes
     [[node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1 (defined at <ipython-input-22-243a182981d9>:9) ]] [Op:__inference_train_function_7663]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1:
 model_3/mixture_same_family_4/MixtureSameFamily/independent_normal_4/IndependentNormal/Softplus (defined at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/distribution_layer.py:988)

Function call stack:
train_function

I am not sure how to change the source code so that it works with a channels-first input shape. Can someone help me with this?


Solution

  • Your preprocess function is returning image, image instead of image, sample['label']. If you change this, it should work!

    I think you can then drop the K.cast in your loss as well.

    Update: when I actually run this, I get NaNs in the loss, so something else is probably wrong. But at least it gets past the shape error! 🤷‍♂️
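For reference, here is a minimal sketch of what the corrected preprocess function might look like. The Gist's code isn't reproduced here, so the details are assumptions: that each sample is a dict with 'image' (height, width, channels, uint8) and 'label' keys, and that the function scales the pixels and moves the channel axis to the front. The only change the answer actually calls for is the return value.

```python
import tensorflow as tf


def preprocess(sample):
    """Map a dict sample to an (input, target) pair for Keras."""
    # Assumption: sample['image'] is [28, 28, 1] uint8; scale it to [0, 1].
    image = tf.cast(sample['image'], tf.float32) / 255.0
    # Assumption: convert to channels_first, [28, 28, 1] -> [1, 28, 28].
    image = tf.transpose(image, perm=[2, 0, 1])
    # The fix: return the label as the target, not the image a second time.
    return image, sample['label']


# Hypothetical stand-in for a single dataset element.
sample = {
    'image': tf.zeros([28, 28, 1], dtype=tf.uint8),
    'label': tf.constant(3, dtype=tf.int64),
}
image, label = preprocess(sample)
```

With a tf.data pipeline this would typically be applied as dataset.map(preprocess) before batching, so the loss compares the model's output distribution against a per-example label rather than a full image tensor.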