Tags: tensorflow, tensorflow-probability, probabilistic-programming

Achieving `observe` behaviour in TensorFlow Probability


Consider the observe statement from probabilistic programming, as defined in [1]:

The observe statement blocks runs which do not satisfy the boolean expression E and does not permit those executions to happen.
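Read as conditioning, which is a standard way to formalize this (the notation below is ours, not quoted from [1]), a program that draws x ~ p and then executes observe(E) returns x with the density below. The denominator P(E) is the probability that a run is not blocked.

```latex
p(x \mid E) \;=\; \frac{p(x)\,\mathbb{1}[E(x)]}{\int p(x')\,\mathbb{1}[E(x')]\,dx'} \;=\; \frac{p(x)\,\mathbb{1}[E(x)]}{P(E)}
```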

Now, consider the following theoretical program:

def f():
    x ~ Normal(0, 1)
    observe(x > 0) # only allow samples x > 0
    return x

which should return samples from Normal(0, 1) truncated to x > 0.

So my question is: how can observe be achieved in TensorFlow Probability, or what is its equivalent? Note that the argument of observe should be able to be an arbitrary (symbolic) boolean expression E, e.g. lambda x: x > 0.

NOTE: Sure, for the program above one can use the HalfNormal distribution, but I am only using it as a simple, concrete example of observe.
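For reference, the intended semantics of f() can be sketched in plain Python (not TFP) as a per-sample rejection loop. The helper `observe_sample` and its `max_tries` cap are hypothetical names chosen for illustration. If the semantics are right, the accepted samples follow HalfNormal(1), whose mean is sqrt(2/pi), about 0.798.

```python
import math
import random
import statistics

def observe_sample(draw, predicate, max_tries=100_000):
    """Hypothetical helper: redraw until `predicate` holds (rejection sampling)."""
    for _ in range(max_tries):
        x = draw()
        if predicate(x):
            return x  # this run satisfies the observe condition
    raise RuntimeError("condition too unlikely: no sample accepted")

def f():
    # x ~ Normal(0, 1); observe(x > 0); return x
    return observe_sample(lambda: random.gauss(0.0, 1.0), lambda x: x > 0)

random.seed(0)
samples = [f() for _ in range(10_000)]
print(min(samples) > 0)  # True: every returned sample satisfies x > 0
# Empirical mean next to the HalfNormal(1) mean sqrt(2/pi).
print(statistics.mean(samples), math.sqrt(2 / math.pi))
```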


[1] Gordon, Andrew D., et al. “Probabilistic programming.” Proceedings of the on Future of Software Engineering. 2014, pp. 167-181.


Solution

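  • In general, the only way to achieve this is with a rejection sampler, which can be expensive. And then you no longer have a tractable density: the conditioned density is proportional to the original density restricted to E, but its normalizing constant P(E) is generally unknown. In general TFP requires all our distributions to have a tractable density (i.e. dist.prob(x)). For this particular case we do have an autodiff-friendly TruncatedNormal, or, as you note, HalfNormal.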

    If you want a general version for an arbitrary condition, it could be as simple as:

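    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    class Rejection(tfd.Distribution):
      """Samples `underlying`, keeping only draws that satisfy `condition`."""

      def __init__(self, underlying, condition, name=None):
        parameters = dict(locals())
        self._u = underlying   # proposal distribution
        self._c = condition    # elementwise predicate, e.g. lambda x: x > 0
        super().__init__(dtype=underlying.dtype,
                         name=name or f'rejection_{underlying.name}',
                         reparameterization_type=tfd.NOT_REPARAMETERIZED,
                         validate_args=underlying.validate_args,
                         allow_nan_stats=underlying.allow_nan_stats,
                         parameters=parameters)

      # Shapes are inherited unchanged from the underlying distribution.
      def _batch_shape(self):
        return self._u.batch_shape
      def _batch_shape_tensor(self):
        return self._u.batch_shape_tensor()
      def _event_shape(self):
        return self._u.event_shape
      def _event_shape_tensor(self):
        return self._u.event_shape_tensor()

      def _sample_n(self, n, seed=None):
        init_seed, loop_seed = tfp.random.split_seed(seed)

        def cond(samples, seed):
          # Keep looping while any draw still violates the condition.
          # (tf.logical_not instead of Python `not`, so this also works in graph mode.)
          return tf.logical_not(tf.reduce_all(self._c(samples)))

        def body(samples, seed):
          # Split the seed each iteration so a stateless seed does not
          # reproduce the same rejected draws forever.
          draw_seed, next_seed = tfp.random.split_seed(seed)
          proposal = self._u.sample(n, seed=draw_seed)
          # Keep accepted draws; replace rejected ones with fresh proposals.
          return tf.where(self._c(samples), samples, proposal), next_seed

        samples, _ = tf.while_loop(
            cond, body, (self._u.sample(n, seed=init_seed), loop_seed))
        return samples

    d = Rejection(tfd.Normal(0., 1.), lambda x: x > -.3)
    s = d.sample(100).numpy()
    print(s.min())  # every sample should be > -0.3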
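The two costs mentioned above (rejection is expensive, and there is no tractable density) can be made concrete with a small plain-Python calculation. This is not TFP code, and the helper names are illustrative. Conditioning Normal(0, 1) on E = (x > a) gives pdf(x) * 1[x > a] / P(E), where P(E) = 1 - Φ(a) has a closed form; that closed form is what lets TruncatedNormal or HalfNormal implement prob(x), whereas for an arbitrary predicate there is generally no such formula. A rejection sampler also needs on average 1 / P(E) proposals per accepted sample, since the number of trials until the first acceptance is geometric.

```python
import math

def normal_cdf(a):
    # Standard normal CDF Φ(a) via the error function.
    return 0.5 * (1.0 + math.erf(a / math.sqrt(2.0)))

def normal_pdf(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def truncated_normal_pdf(x, a):
    """Density of Normal(0, 1) conditioned on x > a."""
    z = 1.0 - normal_cdf(a)  # normalizing constant P(E), closed form for this E
    return normal_pdf(x) / z if x > a else 0.0

for a in (0.0, 2.0, 4.0):
    p_accept = 1.0 - normal_cdf(a)
    # Expected number of proposals per accepted sample is 1 / P(E).
    print(f"a={a}: P(E)={p_accept:.6f}, expected proposals ~ {1 / p_accept:.0f}")
```

For a = 0 this gives P(E) = 0.5, i.e. the HalfNormal case, where the density is simply twice the Normal density on x > 0.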