tensorflow, machine-learning, autoencoder, tensorflow-probability

Forming conditional distributions in TensorFlow Probability


I am using TensorFlow Probability to build a VAE which includes image pixels as well as some other variables. The output of the VAE is:

tfp.distributions.Independent(tfp.distributions.Bernoulli(logits), 2, name="decoder-dist")

I am trying to understand how to form other conditional distributions based on this, which I can then use with the inference methods (MCMC or VI). Say the output above is P(A, B, C | Z). How would I take that distribution and form a posterior P(A | B, C, Z) that I could perform inference on? I have been reading through the docs, but I am having some trouble grasping them.


Solution

  • The answer to your question depends very much on the nature of the joint model within which you'd like to do the conditioning. Much has been written about the topic, and in short it's a very hard problem in general :). Without knowing a bit more about the particulars of your problem, it's nearly impossible to recommend a useful generic inference procedure. However, we do have a bunch of examples (scripts and Jupyter/Colab notebooks) in the TFP repo here: https://github.com/tensorflow/probability/tree/master/tensorflow_probability/examples

    In particular, there's

    • The Hierarchical Linear Model example, which is a sort of Rosetta stone showing how to do posterior inference using Hamiltonian Monte Carlo (an MCMC technique) in TFP, R, and Stan,

    • The Linear Mixed Effects Model example, showing how you might use VI to solve a standard LME problem,

    among many others. You can click the "Run in Google Colab" link at the top of any of these notebooks to open and run them on https://colab.research.google.com.

    Please also feel free to reach out to us via email at [email protected]. This is a public Google Group where users can engage directly with the team that builds TFP. If you give us some more info there on what you'd like to do, we're happy to provide guidance on modeling and inference with TFP.

    Hope this gives at least a start in the right direction!