Tags: python, tensorflow, cdf

CDF of MultivariateNormalDiag in tensorflow


I can run the following example (source link not preserved):

# Repro from the question: building the distribution and evaluating its
# density with dist.pdf works.
mu = [1, 2, 3.]
diag_stdev = [4, 5, 6.]
dist = tf.contrib.distributions.MultivariateNormalDiag(mu, diag_stdev)
dist.pdf([-1., 0, 1])
# Replacing the line above with dist.cdf([-1., 0, 1]) raises
# "NotImplementedError: log_cdf is not implemented" (as reported below).

but when I replace the last line with dist.cdf([-1., 0, 1]), I get a NotImplementedError:

NotImplementedError: log_cdf is not implemented

Can anybody suggest a workaround, at least for the time being?


Solution

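  • Based on two earlier answers (links not preserved), I implemented the following workaround. It evaluates the CDF in NumPy/SciPy through tf.py_func and registers a hand-written gradient for that op. Because the covariance is diagonal, the components are independent, so the joint CDF is the product of one-dimensional normal CDFs: $F(x) = \prod_i \Phi\big((x_i - \mu_i)/\sigma_i\big)$.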

    import tensorflow as tf
    import numpy as np
    from scipy.stats import mvn
    
    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
        """Wrap ``tf.py_func`` so that ``grad`` is used as the op's gradient.

        Args:
            func: Python callable operating on NumPy arrays (see ``tf.py_func``).
            inp: List of input tensors passed to ``func``.
            Tout: Output dtype(s), forwarded to ``tf.py_func``.
            stateful: Forwarded to ``tf.py_func``.
            name: Optional op name.
            grad: Gradient function ``grad(op, *grads)`` to attach to the
                created op, or ``None`` to keep TensorFlow's default for
                py_func ops.

        Returns:
            Whatever ``tf.py_func`` returns.
        """
        if grad is None:
            # Nothing to override; don't pollute the global gradient registry.
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

        import uuid

        # Gradient names live in a process-global registry and must be unique.
        # A uuid avoids collisions, and also avoids np.random.randint(0, 1E+10),
        # which can raise where NumPy's default integer is 32-bit.
        rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
        tf.RegisterGradient(rnd_name)(grad)

        g = tf.get_default_graph()
        # tf.py_func builds a "PyFunc" op when stateful and a "PyFuncStateless"
        # op otherwise; override both so `grad` applies in either case.
        with g.gradient_override_map({"PyFunc": rnd_name,
                                      "PyFuncStateless": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    
    def np_cdf(mean, diag_sigma, value, name=None):
        """NumPy CDF of a multivariate normal with diagonal covariance.

        Row ``k`` of the result is ``P(X <= value[k])`` for
        ``X ~ N(mean[k], diag(diag_sigma[k] ** 2))``. With a diagonal
        covariance the components are independent, so the joint CDF is the
        product of the one-dimensional normal CDFs. This is exact, works for
        any dimension ``d``, and needs no truncated lower integration bound.

        Args:
            mean: Means, shape (batch, d).
            diag_sigma: Per-dimension standard deviations (not variances),
                shape (batch, d); expected to be positive.
            value: Points at which to evaluate the CDF, shape (batch, d).
            name: Unused; kept for signature compatibility.

        Returns:
            float32 array of shape (batch, 1).
        """
        from scipy.special import ndtr  # standard normal CDF

        mean = np.asarray(mean, dtype=np.float64)
        diag_sigma = np.asarray(diag_sigma, dtype=np.float64)
        value = np.asarray(value, dtype=np.float64)

        # Standardize with the standard deviation (the old code passed
        # np.diag(diag_sigma) as covariance, i.e. treated sigma as variance).
        z = (value - mean) / diag_sigma
        cdf = np.prod(ndtr(z), axis=-1)
        return np.asarray(cdf, dtype=np.float32).reshape([-1, 1])
    
    def cdf_gradient(op, grad):
        """Gradient of the diagonal-normal CDF op w.r.t. its three inputs.

        With z_i = (x_i - mu_i) / sigma_i the joint CDF factorizes as
        F = prod_i Phi(z_i). Writing P_i = prod_{j != i} Phi(z_j) and phi for
        the standard normal density:

            dF/dx_i     =  P_i * phi(z_i) / sigma_i
            dF/dmu_i    = -dF/dx_i
            dF/dsigma_i = -z_i * dF/dx_i

        Args:
            op: The py_func op; its inputs are (mean, diag_sigma, value).
            grad: Upstream gradient w.r.t. the op's output, presumably shape
                (batch, 1) to match the forward output -- TODO confirm.

        Returns:
            Tuple of gradients w.r.t. (mean, diag_sigma, value).
        """
        import math

        mu = op.inputs[0]
        diag_sigma = op.inputs[1]
        value = op.inputs[2]

        z = (value - mu) / diag_sigma
        # Per-dimension standard normal CDF and density.
        cdf_1d = 0.5 * tf.erfc(-z / math.sqrt(2.0))
        pdf_1d = tf.exp(-0.5 * tf.square(z)) / math.sqrt(2.0 * math.pi)
        # Product of the *other* dimensions' CDFs via exclusive prefix/suffix
        # products, avoiding a division by a possibly tiny Phi(z_i).
        others = (tf.cumprod(cdf_1d, axis=-1, exclusive=True) *
                  tf.cumprod(cdf_1d, axis=-1, exclusive=True, reverse=True))

        dc_dv = others * pdf_1d / diag_sigma
        dc_dm = -dc_dv
        dc_ds = -z * dc_dv
        return grad * dc_dm, grad * dc_ds, grad * dc_dv
    
    def tf_cdf(mean, diag_sigma, value, name=None):
        """CDF of a diagonal-covariance multivariate normal as a TF op.

        The forward value is computed by ``np_cdf`` through ``tf.py_func``;
        ``cdf_gradient`` is registered as the op's gradient.

        Args:
            mean: Means tensor, presumably float32 of shape (batch, d) -- TODO
                confirm against callers.
            diag_sigma: Per-dimension scales, same shape as ``mean`` (intended
                as standard deviations, like MultivariateNormalDiag's
                ``diag_stdev``).
            value: Points at which to evaluate the CDF, same shape as ``mean``.
            name: Optional name scope (defaults to "MyCDF").

        Returns:
            float32 tensor of shape (batch, 1).
        """
        with tf.name_scope(name, "MyCDF", [mean, diag_sigma, value]) as name:
            cdf = py_func(np_cdf,
                          [mean, diag_sigma, value],
                          [tf.float32],
                          name=name,
                          grad=cdf_gradient)  # attach the custom gradient
            result = cdf[0]
            # tf.py_func discards static shape information; np_cdf always
            # returns a (-1, 1)-shaped array, so restore rank and last dim.
            result.set_shape([None, 1])
            return result