
How to use MCMC sampling on a custom function with TensorFlow Probability


I'm just starting out with TensorFlow, and I'm unsure how to sample from a custom probability distribution that is not easily expressible as a composition of the stock distributions.

How can I use a custom function as the target_log_prob_fn input to the MCMC samplers?


Solution

  • You can pass any target_log_prob_fn to the tfp.mcmc.HamiltonianMonteCarlo TransitionKernel, as long as it returns the log of your target density up to an additive constant (i.e. the log of something proportional to the density) and is differentiable with respect to its inputs, since HMC uses gradients. E.g.

    def target_log_prob_fn(x):
      return -.5 * x ** 2
    

    is a perfectly valid target log prob function. If you want to sample multiple chains in parallel, you'll need to make sure your target is "batch-friendly", i.e. that it returns one log-prob value per chain. For example, if you need to reduce_sum over some part of the state (say, for a multivariate distribution), be explicit about which axes you're summing over:

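    import tensorflow as tf
    import tensorflow_probability as tfp

    def target_log_prob_fn(x):
      return -.5 * tf.reduce_sum(x ** 2, axis=-1)  # don't sum over chains!
    
    ...
    
    tfp.mcmc.sample_chain(
        kernel=kernel,  # must be a keyword: the first positional arg is num_results
        num_burnin_steps=100,
        num_results=100,
        current_state=tf.zeros([10, 5]),  # 10 parallel chains of 5-D variables
    )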
    

    HTH!
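To see concretely what "batch-friendly" means, a quick sanity check is to evaluate your log density on a batch of states and look at the output shape: it should have one entry per chain. The two function names below are hypothetical, chosen only for this comparison.

```python
import tensorflow as tf


def log_prob_batch_friendly(x):
    # Reduce over the event axis only: one log-prob per chain.
    return -0.5 * tf.reduce_sum(x ** 2, axis=-1)


def log_prob_wrong(x):
    # Reduces over ALL axes, collapsing the chain axis into a single scalar.
    return -0.5 * tf.reduce_sum(x ** 2)


state = tf.ones([10, 5])  # 10 chains, 5-D state each

good = log_prob_batch_friendly(state)
bad = log_prob_wrong(state)
print(good.shape)  # (10,)  each entry is -2.5
print(bad.shape)   # ()     a single value, -25.0
```

With the second version the kernel only ever sees one log-prob for the whole batch, so in general the accept/reject decision is no longer made per chain; the 10 chains effectively become one joint 50-D chain.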