I'm trying to use JAX to generate samples from a multivariate normal distribution using:
import jax
import jax.numpy as jnp
import numpy as np
key = jax.random.PRNGKey(0)
cov = np.array([[1.2, 0.4], [0.4, 1.0]])
mean = np.array([3,-1])
x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T
However, when I run the code, I get the following error:
TypeError Traceback (most recent call last)
<ipython-input-25-1397bf923fa4> in <module>()
2 cov = np.array([[1.2, 0.4], [0.4, 1.0]])
3 mean = np.array([3,-1])
----> 4 x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T
1 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in canonicalize_shape(shape)
1159 "got {}.")
1160 if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
-> 1161 and not isinstance(get_aval(x), ConcreteArray) for x in shape):
1162 msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
1163 "smaller subfunctions.")
TypeError: 'int' object is not iterable
I'm not sure what the problem is, since the same syntax works for the equivalent function in NumPy.
In the jax.random module, most shapes must explicitly be tuples. So instead of shape 5000, use (5000,):
x1,x2 = jax.random.multivariate_normal(key, mean, cov, (5000,)).T
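For reference, here is the fix as a complete script you can run on its own. It is a minimal sketch: the `samples` variable is only for illustration, and I've written the mean as floats and used `jnp.array`, which is an assumption on my part rather than something the fix requires.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3.0, -1.0])

# `shape` is the batch shape and is passed as a tuple, not a bare int.
# The event dimension (2 here) is appended, so the result is (5000, 2).
samples = jax.random.multivariate_normal(key, mean, cov, shape=(5000,))

# Transpose to (2, 5000) and unpack one array per coordinate.
x1, x2 = samples.T

print(samples.shape)
print(x1.shape, x2.shape)
```

Note that the tuple gives only the number of draws; the trailing dimension comes from `mean` and `cov`, which is why unpacking `.T` yields exactly two arrays.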