python numpy tensorflow batching

Tensorflow MultivariateNormalDiag tensor of shape (None, output_dim, output_dim, output_dim) given mu and sigma of shape (None, 3)


I am trying to build a tensor of multivariate Gaussians using MultivariateNormalDiag.

I would like to supply two tensors of shape (None, 3) for the mu and sigma parameters, like so:

dist = tf.contrib.distributions.MultivariateNormalDiag(mu, sigma)

so that I can then supply a set of points, in this case:

dim_range = [float(i) for i in range(0, max_size)]
points = [[a,b,c] for a in dim_range for b in dim_range for c in dim_range]

and retrieve the density at each of those points, normally distributed around mu, as:

gauss_tensor = tf.reshape(
    dist.pdf(points),
    # tf.reshape does not accept None; -1 lets TensorFlow infer the batch dimension
    shape=(-1, output_dim, output_dim, output_dim)
)

This is for a single example, e.g. mu and sigma have shape (3,) and the output has shape (output_dim, output_dim, output_dim). Visualized in three dimensions, we get this visualization

for output_dim = 16, with mu and sigma chosen semi-randomly to show the variance in each dimension. A full working example can be found here and an example of what I'm trying to achieve here [edit: for 1.0 onwards, dist.pdf(points) needs to be changed to dist.prob(points)]

However, if the same is tried for a batch of unknown size, so that the output would have shape (None, output_dim, output_dim, output_dim), everything crashes, and the error messages differ depending on which approach I try.

Does anyone know how to accomplish this for varying batch sizes where each batch element has a corresponding mu and sigma in a batch of mus and a batch of sigmas?

Thanks in advance :)

P.S. This is using TensorFlow 0.12, but if there are fixes in 1.*, I will consider rebuilding TensorFlow.


Solution

  • As a friend pointed out, MultivariateNormalDiag behaves differently in 1.2. Upgrading TensorFlow and re-aligning a few things, such as the placeholder shape below, sorted the issue.

    mu_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=(None, None, 3),
        name='mu-tensor')
    

    [edit: a shape of (None, 1, 3) for the mus/sigmas also gives the correct result]

    mu_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=(None, 1, 3),
        name='mu-tensor')
    

    A working example is here
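
Why a shape of (None, 1, 3) lines up with the points can be illustrated outside TensorFlow. The following is a minimal NumPy sketch, not the code from the working example above and not the TensorFlow API. It assumes sigma holds per-axis standard deviations, and the function name diag_gaussian_grid is made up for illustration. The extra middle axis on mu and sigma lets the (N, 3) list of points broadcast against every batch element, giving (batch, N) densities that reshape to (batch, output_dim, output_dim, output_dim).

```python
import numpy as np


def diag_gaussian_grid(mu, sigma, output_dim):
    """Evaluate a batch of diagonal Gaussians on an output_dim**3 integer grid.

    mu, sigma: arrays of shape (batch, 3); sigma is assumed to hold
    per-axis standard deviations.
    Returns an array of shape (batch, output_dim, output_dim, output_dim).
    """
    # Insert a middle axis so each batch element broadcasts over all points.
    mu = np.asarray(mu, dtype=np.float64)[:, None, :]        # (batch, 1, 3)
    sigma = np.asarray(sigma, dtype=np.float64)[:, None, :]  # (batch, 1, 3)

    # Same ordering as the question's nested list comprehension:
    # a varies slowest, c varies fastest. Shape (N, 3), N = output_dim**3.
    axis = np.arange(output_dim, dtype=np.float64)
    grid = np.meshgrid(axis, axis, axis, indexing="ij")
    points = np.stack(grid, axis=-1).reshape(-1, 3)

    # (N, 3) against (batch, 1, 3) broadcasts to (batch, N, 3).
    z = (points - mu) / sigma
    log_norm = -np.log(sigma * np.sqrt(2.0 * np.pi))   # (batch, 1, 3)
    log_pdf = np.sum(log_norm - 0.5 * z ** 2, axis=-1)  # (batch, N)

    # -1 infers the batch size, which may differ from call to call.
    return np.exp(log_pdf).reshape(-1, output_dim, output_dim, output_dim)


if __name__ == "__main__":
    mu = np.array([[8.0, 8.0, 8.0], [2.0, 4.0, 6.0]])
    sigma = np.array([[2.0, 2.0, 2.0], [1.0, 2.0, 3.0]])
    print(diag_gaussian_grid(mu, sigma, 16).shape)
```

Each batch element gets its own density volume, and the batch size is only fixed when the function is called, which is the role the None dimension plays in the placeholder.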