Implementing a Jacobian for a polar-to-Cartesian coordinate transformation, I obtain an array of zeros in JAX, which can't be right.
import numpy as np
import jax
import jax.numpy as jnp

theta = np.pi/4
r = 4.0
var = np.array([r, theta])
x = var[0]*jnp.cos(var[1])
y = var[0]*jnp.sin(var[1])
def f(var):
    return np.array([x, y])
jac = jax.jacobian(f)(var)
jac
#DeviceArray([[0., 0.],
# [0., 0.]], dtype=float32)
What am I missing?
Your function has no dependence on var: x and y are computed outside the function, so f returns constants and its Jacobian with respect to var is zero.
This would give the desired output instead:
theta = np.pi/4
r = 4.0
var = np.array([r, theta])
def f(var):
    x = var[0]*jnp.cos(var[1])
    y = var[0]*jnp.sin(var[1])
    return jnp.array([x, y])
jac = jax.jacobian(f)(var)
jac
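Note that the function also needs to return a JAX array (jnp.array) rather than a NumPy array (np.array), because NumPy cannot convert the traced values JAX passes through the function during differentiation.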
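As a sanity check, the result can be compared against the Jacobian worked out by hand, [[cos(theta), -r*sin(theta)], [sin(theta), r*cos(theta)]]. The following is only a sketch: the names polar_to_cartesian and analytic_jacobian are made up for illustration and are not part of JAX.

```python
import numpy as np
import jax
import jax.numpy as jnp

def polar_to_cartesian(var):
    # x and y are computed inside the function, so they depend on var
    r, theta = var[0], var[1]
    return jnp.array([r * jnp.cos(theta), r * jnp.sin(theta)])

def analytic_jacobian(var):
    # Hand-derived partial derivatives of (x, y) with respect to (r, theta)
    r, theta = var[0], var[1]
    return jnp.array([[jnp.cos(theta), -r * jnp.sin(theta)],
                      [jnp.sin(theta),  r * jnp.cos(theta)]])

var = jnp.array([4.0, np.pi / 4])
jac = jax.jacobian(polar_to_cartesian)(var)
expected = analytic_jacobian(var)

# True when automatic differentiation agrees with the hand-derived matrix
matches = bool(jnp.allclose(jac, expected, atol=1e-5))
print(matches)
```

If the two matrices agree, the function is being differentiated with respect to var as intended; an all-zero result would instead suggest that the output does not depend on the argument.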