Constructing discrete table-based CPDs in tensorflow-probability?


I'm trying to construct the simplest example of a Bayesian network with several discrete random variables and conditional probabilities (the "Student Network" from Koller's book, see 1).

Although it is a bit unwieldy, I managed to build this network using pymc3. In particular, creating the CPDs is not straightforward in pymc3; see the snippet below:

import pymc3 as pm
import theano.tensor  # needed for theano.tensor.eq below

...

with pm.Model() as basic_model:
    # parameters for categorical are indexed as [0, 1, 2, ...]
    difficulty = pm.Categorical(name='difficulty', p=[0.6, 0.4])

    intelligence = pm.Categorical(name='intelligence', p=[0.7, 0.3])

    grade = pm.Categorical(name='grade',
        p=pm.math.switch(
            theano.tensor.eq(intelligence, 0),
            pm.math.switch(
                theano.tensor.eq(difficulty, 0),
                [0.3, 0.4, 0.3],    # I=0, D=0
                [0.05, 0.25, 0.7]   # I=0, D=1
            ),
            pm.math.switch(
                theano.tensor.eq(difficulty, 0),
                [0.9, 0.08, 0.02],  # I=1, D=0
                [0.5, 0.3, 0.2]     # I=1, D=1
            )
        )
    )

    letter = pm.Categorical(name='letter', p=pm.math.switch(
    ...

But I have no idea how to build this network using tensorflow-probability (versions: tfp-nightly==0.7.0.dev20190517, tf-nightly-2.0-preview==2.0.0.dev20190517).

For the unconditioned binary variables, one can use a categorical distribution, such as:

from tensorflow_probability import distributions as tfd
from tensorflow_probability import edward2 as ed

difficulty = ed.RandomVariable(
                 tfd.Categorical(
                     probs=[0.6, 0.4],
                     name='difficulty'
                 )
             )

But how do I construct the CPDs?

There are a few classes/methods in tensorflow-probability that might be relevant (in tensorflow_probability/python/distributions/deterministic.py, or the deprecated ConditionalDistribution), but the documentation is rather sparse (one needs a deep understanding of tfp).

--- Updated question ---

Chris' answer is a good starting point. However, things are still a bit unclear to me, even for a very simple two-variable model.

This works nicely:

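import tensorflow as tf

jdn = tfd.JointDistributionNamed(dict(
    # pass probs= explicitly: the first positional argument of tfd.Categorical is logits
    dist_x=tfd.Categorical(probs=[0.2, 0.8], validate_args=True),
    dist_y=lambda dist_x: tfd.Bernoulli(probs=tf.gather([0.1, 0.9], indices=dist_x), validate_args=True)
))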
print(jdn.sample(10))

but this one fails:

jdn = tfd.JointDistributionNamed(dict(
    dist_x=tfd.Categorical(probs=[0.2, 0.8], validate_args=True),
    dist_y=lambda dist_x: tfd.Categorical(probs=tf.gather_nd([[0.1, 0.9], [0.5, 0.5]], indices=[dist_x]))
))
print(jdn.sample(10))

(I'm modeling the categorical explicitly in the second example purely for learning purposes.)

--- Update: solved ---

The last example wrongly used tf.gather_nd instead of tf.gather: we only want to select the first or the second row of the table based on the dist_x outcome. This code works now:

jdn = tfd.JointDistributionNamed(dict(
    dist_x=tfd.Categorical(probs=[0.2, 0.8], validate_args=True),
    dist_y=lambda dist_x: tfd.Categorical(probs=tf.gather([[0.1, 0.9], [0.5, 0.5]], indices=[dist_x]))
))
print(jdn.sample(10))
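
To make the shape handling concrete, here is a minimal sketch, assuming TensorFlow 2.x with eager execution. The names table and x are illustrative only (x stands in for a batch of dist_x samples); they are not part of the original code. It compares the shapes produced by tf.gather, by tf.gather with the indices wrapped in a list, and by tf.gather_nd with an explicit trailing coordinate axis.

```python
import tensorflow as tf

# Row i holds P(y | x = i); same numbers as the table in the example above.
table = tf.constant([[0.1, 0.9], [0.5, 0.5]])

# Hypothetical batch of sampled parent values, standing in for dist_x.
x = tf.constant([0, 1, 1])

# tf.gather picks one row per index: result shape is x.shape + [2] = [3, 2].
rows = tf.gather(table, x)

# Wrapping the indices in a list adds a leading axis of size 1: shape [1, 3, 2].
rows_wrapped = tf.gather(table, [x])

# tf.gather_nd reads the LAST axis of the indices as a coordinate into the table,
# so selecting one row per sample needs indices of shape [3, 1].
rows_nd = tf.gather_nd(table, x[:, tf.newaxis])

print(rows.shape, rows_wrapped.shape, rows_nd.shape)
```

If this reading is right, passing indices=dist_x (without the extra list) should give dist_y the same leading shape as dist_x, while indices=[dist_x] carries an extra leading dimension of size 1 into the result.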

Solution

  • The tricky thing about this, and presumably the reason it's subtler than expected in PyMC, is -- as with almost everything in vectorized programming -- handling shapes.

    In TF/TFP, the (IMO) nicest way to solve this is with one of the new TFP JointDistribution{Sequential,Named,Coroutine} classes. These let you represent hierarchical PGMs naturally, and then sample from them, evaluate log probs, etc.

    I whipped up a colab notebook demoing all 3 approaches, for the full student network: https://colab.research.google.com/drive/1D2VZ3OE6tp5pHTsnOAf_7nZZZ74GTeex

    Note the crucial use of tf.gather and tf.gather_nd to manage the vectorization of the various binary and categorical switching.

    Have a look and let me know if you have any questions!