Tags: python, tensorflow, monte-carlo, tensorflow-probability

How to use tfp.distributions.Mixture with JointDistributionCoroutine


I'm trying to define a model function for MCMC. The idea is to have a mixture of two distributions controlled by a mixing probability. One of my attempts looks like this:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

root = tfd.JointDistributionCoroutine.Root

def model_fn():
    rv_p     = yield root(tfd.Sample(tfd.Uniform(0.0,1.0),1))

    catprobs = tf.stack([rv_p, 1.-rv_p],0)
    rv_cat = tfd.Categorical(probs=catprobs)

    rv_norm1  = tfd.Sample(tfd.Normal(0.0,1.0),1)
    rv_norm2  = tfd.Sample(tfd.Normal(3.0,1.0),1)

    rv_mix = yield tfd.Mixture(cat=rv_cat,
                     components=[
                        rv_norm1,
                        rv_norm2,
                     ])

jd = tfd.JointDistributionCoroutine(model_fn)
jd.sample(2)

The code fails with:

ValueError: components[0] batch shape must be compatible with cat shape and other component batch shapes ((2, 2) vs ())

Could you give me an example of how to use the Mixture distribution in a way that allows "any" shape of inputs?

I'm using TensorFlow 2.4.1 and tensorflow_probability 0.12.1 with Python 3.6.


Solution

  • I figured it out. The key change is to build the Mixture from the already-sampled values (rv_p, rv_mu), so the Categorical and both components end up with matching batch shapes. For reference, here is working sample code that fits a mixture of Normal(mu, 1) and a point mass at zero to data that is half standard-normal draws and half exact zeros:

    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    import tensorflow as tf
    import tensorflow_probability as tfp
    import matplotlib.pyplot as plt
    tfd = tfp.distributions
    tfb = tfp.bijectors
    
    import numpy as np
    from time import time
    
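    # Synthetic data: the first half are N(0, 1) draws, the second half are set to exactly zero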
    numdata = 10000
    data = np.random.normal(0.0,1.0,numdata).astype(np.float32)
    data[int(numdata/2):] = 0.0
    _=plt.hist(data,30,density=True)
    
    root = tfd.JointDistributionCoroutine.Root
    def dist_fn(rv_p, rv_mu):
        # Build the likelihood from the already-sampled values rv_p and rv_mu,
        # so the Categorical and both mixture components share one batch shape.
        rv_cat = tfd.Categorical(probs=tf.stack([rv_p, 1.-rv_p], -1))
        rv_norm = tfd.Normal(rv_mu, 1.0)                    # "signal" component
        rv_zero = tfd.Deterministic(tf.zeros_like(rv_mu))   # point mass at zero

        # Independent reinterprets the rightmost batch dimension as part of the
        # event, so log_prob sums over it.
        rv_mix = tfd.Independent(
                    tfd.Mixture(cat=rv_cat,
                                components=[rv_norm, rv_zero]),
                    reinterpreted_batch_ndims=1)
        return rv_mix
    
    
    def model_fn():
        # Priors: mixture weight p ~ Uniform(0, 1), location mu ~ Uniform(-1, 1)
        rv_p  = yield root(tfd.Sample(tfd.Uniform(0.0, 1.0), 1))
        rv_mu = yield root(tfd.Sample(tfd.Uniform(-1., 1.), 1))

        # Likelihood: mixture of Normal(mu, 1) and a point mass at zero
        rv_mix = yield dist_fn(rv_p, rv_mu)
        
    jd = tfd.JointDistributionCoroutine(model_fn)
    # Condition on the observations by pinning data as the value of the last variable
    unnormalized_posterior_log_prob = lambda *args: jd.log_prob(args + (data,))
    
    n_chains = 1
    
    p_init = [0.3]
    p_init = tf.cast(p_init,dtype=tf.float32)
    
    mu_init = 0.1
    mu_init = tf.stack([mu_init]*n_chains,axis=0)
    
    initial_chain_state = [
        p_init,
        mu_init,
    ]
    
    bijectors = [
        tfb.Sigmoid(),   # p: map the unconstrained HMC state into (0, 1)
        tfb.Identity(),  # mu: unconstrained
    ]
    
    step_size = 0.01
    
    num_results = 50000
    num_burnin_steps = 50000
    
    
    # HMC runs in the unconstrained space defined by the bijectors, with the
    # step size adapted during the first 80% of burn-in
    kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=unnormalized_posterior_log_prob,
            num_leapfrog_steps=2,
            step_size=step_size,
            state_gradients_are_stopped=True),
        bijector=bijectors)

    kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=kernel, num_adaptation_steps=int(num_burnin_steps * 0.8))
    
    # XLA-compile the sampling loop for speed
    @tf.function(autograph=False, experimental_compile=True)
    def graph_sample_chain(*args, **kwargs):
      return tfp.mcmc.sample_chain(*args, **kwargs)
    
    
    st = time()
    trace,stats = graph_sample_chain(
          num_results=num_results,
          num_burnin_steps=num_burnin_steps,
          current_state=initial_chain_state,
          kernel=kernel)
    et = time()
    print(et-st)
    
    
    # Posterior histograms and means of the traces for p and mu
    ptrace, mutrace = trace
    plt.subplot(121)
    _=plt.hist(ptrace.numpy(),100,density=True)
    plt.subplot(122)
    _=plt.hist(mutrace.numpy(),100,density=True)
    print(np.mean(ptrace),np.mean(mutrace))
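
    The same batch-shape idea also fixes the model from the question directly:
    build the Categorical and both components from the sampled value of rv_p so
    that all of them share one common batch shape. Below is a minimal, untested
    sketch of that fix (model_fn_fixed and jd_fixed are just illustrative names;
    it reuses tf, tfd and root from the listing above and keeps the question's
    two Normal components):

    def model_fn_fixed():
        # rv_p is the sampled tensor here, shape [..., 1]
        rv_p = yield root(tfd.Sample(tfd.Uniform(0.0, 1.0), 1))

        # Categorical batch shape [..., 1]; the components' loc parameters are
        # shaped like rv_p, so their batch shapes match the Categorical's
        catprobs = tf.stack([rv_p, 1. - rv_p], -1)
        rv_mix = yield tfd.Mixture(
            cat=tfd.Categorical(probs=catprobs),
            components=[
                tfd.Normal(tf.zeros_like(rv_p), 1.0),
                tfd.Normal(3.0 + tf.zeros_like(rv_p), 1.0),
            ])

    jd_fixed = tfd.JointDistributionCoroutine(model_fn_fixed)
    jd_fixed.sample(2)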