python tensorflow probability-density tensorflow-probability

How should tensorflow_probability distributions be used for multi-dimensional spaces?


I would like to create a multi-dimensional Gaussian probability density function (say, a 2D Gaussian like the one in the figure below) with TensorFlow.

[figure: 2D Gaussian density]

For 1D, it works like a charm:

import tensorflow as tf
import tensorflow_probability as tfp
d = tfp.distributions.Normal(loc=5.0, scale=3.0)
x = d.prob(tf.range(0, 10, dtype=tf.float32))

But for higher dimensions, I get an InvalidArgumentError: Incompatible shapes error when using the Normal or MultivariateNormalDiag distributions. What am I missing? How should the prob method be used to output the probability density function on a multi-dimensional tensor?


Solution

  • If I understood correctly, you can do something like:

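    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    mu = [0, 0]
    cov = [[1, 0],
           [0, 1]]
    # Draw 1000 samples, then estimate the mean and per-dimension standard deviation
    mv_normal = np.random.multivariate_normal(mu, cov, size=1000)
    mv_normal_mean = np.mean(mv_normal, axis=0)
    mv_normal_cov = np.cov(mv_normal, rowvar=False)
    mv_normal_diag = np.diag(mv_normal_cov)
    mv_normal_stddev = np.sqrt(mv_normal_diag)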
    

    mv_normal looks like this:

    mv_normal
    array([[-1.73476374,  0.17578855],
           [ 0.11866498, -0.66417069],
           [ 1.52000069, -1.3004096 ],
           ...,
           [-1.37625595, -0.46864374],
           [ 0.81659449,  0.70524036],
           [ 1.12183633,  0.14196896]])
    

    mv_normal_mean, mv_normal_cov, and the other values are plain NumPy arrays. The mean and standard deviation are used to create the distribution:

    mvn = tfd.MultivariateNormalDiag(
        loc=mv_normal_mean,
        scale_diag=mv_normal_stddev)
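
For intuition about the shapes involved, here is a minimal NumPy sketch of the density that a diagonal-covariance Gaussian assigns to a batch of points. It is not TFP code: `diag_gaussian_pdf` is a hypothetical helper written only for illustration. The convention it follows, which `MultivariateNormalDiag` shares, is that the last axis of the input holds the event dimensions (here 2), so an input of shape (N, 2) yields N density values.

```python
import numpy as np

def diag_gaussian_pdf(x, loc, scale_diag):
    """Density of a Gaussian with diagonal covariance, evaluated along the last axis of x."""
    x = np.asarray(x, dtype=np.float64)
    loc = np.asarray(loc, dtype=np.float64)
    scale_diag = np.asarray(scale_diag, dtype=np.float64)
    z = (x - loc) / scale_diag
    # Independent dimensions: the joint log-density is the sum of 1D normal log-densities.
    log_pdf = -0.5 * z**2 - np.log(scale_diag) - 0.5 * np.log(2.0 * np.pi)
    return np.exp(log_pdf.sum(axis=-1))

# Three 2D points: shape (3, 2), last axis is (x, y).
points = np.array([[0.0, 0.0], [1.0, -1.0], [2.0, 0.5]])
density = diag_gaussian_pdf(points, loc=[0.0, 0.0], scale_diag=[1.0, 1.0])
print(density.shape)  # (3,)
```

If the last axis of the input does not match the length of `loc`, the subtraction cannot broadcast, which is the kind of mismatch that typically surfaces as an incompatible-shapes error.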
    

    The estimated values are:

    mv_normal_mean
    array([-0.03976356,  0.07387231])
    

    mv_normal_cov
    array([[ 1.04138867, -0.00877481],
           [-0.00877481,  0.97736496]])
    

    You can then plot the density as a 3D surface:

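    # Build a grid from the sample coordinates; each row of `data` is one (x, y) point
    x1, x2 = np.meshgrid(mv_normal[:, 0], mv_normal[:, 1])
    data = np.stack((x1.flatten(), x2.flatten()), axis=1)
    # Evaluate the density at every grid point (input shape: [N, 2], output shape: [N])
    prob = mvn.prob(data).numpy()
    plt.figure(figsize=(12, 9))
    ax = plt.axes(projection='3d')
    ax.plot_surface(x1, x2, prob.reshape(x1.shape), cmap='Blues')
    plt.show()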
    

    That produces the following plot: [figure: 3D surface of the 2D Gaussian density]
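
The plot above builds its mesh from the raw sample coordinates, which are unsorted and give a 1000 x 1000 grid. A common alternative is a regular grid over a fixed range. The sketch below is NumPy-only and assumes a standard 2D Gaussian (zero mean, identity covariance) and a range of -4 to 4 with 81 points per axis; both choices are illustrative assumptions, not part of the answer above.

```python
import numpy as np

# Regular, sorted grid covering roughly +/- 4 standard deviations (an assumed range).
xs = np.linspace(-4.0, 4.0, 81)
ys = np.linspace(-4.0, 4.0, 81)
x1, x2 = np.meshgrid(xs, ys)         # each has shape (81, 81)
grid = np.stack([x1, x2], axis=-1)   # shape (81, 81, 2); last axis is (x, y)

# Standard 2D Gaussian density, computed directly in NumPy.
density = np.exp(-0.5 * np.sum(grid**2, axis=-1)) / (2.0 * np.pi)

# Sanity check: a Riemann sum over the grid should be close to 1.
cell_area = (xs[1] - xs[0]) * (ys[1] - ys[0])
total_mass = density.sum() * cell_area
print(density.shape, round(total_mass, 3))
```

With TFP, the same `grid` array could be passed to `mvn.prob(grid)`. Because the last axis has length 2, the result should have shape (81, 81) and can be handed to `ax.plot_surface(x1, x2, ...)` without flattening and reshaping.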