Obtaining zeros for this derivative in JAX


When implementing the Jacobian of the polar-to-Cartesian coordinate transformation, I obtain an array of zeros in JAX, which can't be right:
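For reference, the expected result is the Jacobian of the map (r, θ) → (x, y) = (r cos θ, r sin θ). At r = 4, θ = π/4 its entries are ±√2/2 ≈ ±0.707 and ±2√2 ≈ ±2.828, so it is clearly not the zero matrix:

```latex
J(r,\theta) = \frac{\partial(x, y)}{\partial(r, \theta)}
= \begin{pmatrix} \cos\theta & -r\sin\theta \\ \sin\theta & r\cos\theta \end{pmatrix},
\qquad
J\!\left(4, \tfrac{\pi}{4}\right)
= \begin{pmatrix} \tfrac{\sqrt{2}}{2} & -2\sqrt{2} \\ \tfrac{\sqrt{2}}{2} & 2\sqrt{2} \end{pmatrix}
\approx \begin{pmatrix} 0.707 & -2.828 \\ 0.707 & 2.828 \end{pmatrix}
```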

import numpy as np
import jax
import jax.numpy as jnp

theta = np.pi/4
r = 4.0

var = np.array([r, theta])

x = var[0]*jnp.cos(var[1])
y = var[0]*jnp.sin(var[1])

def f(var):
    return np.array([x, y])

jac = jax.jacobian(f)(var)
jac

#DeviceArray([[0., 0.],
#             [0., 0.]], dtype=float32)

What am I missing?


Solution

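  • Your function has no dependence on var: x and y are computed once, outside f, from the concrete values of var. f therefore ignores its argument and returns a constant, and the derivative of a constant is zero.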

    This would give the desired output instead:

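    import numpy as np
    import jax
    import jax.numpy as jnp

    theta = np.pi/4
    r = 4.0

    var = np.array([r, theta])

    def f(var):
        # x and y are now computed from the argument, so JAX can trace them
        x = var[0]*jnp.cos(var[1])
        y = var[0]*jnp.sin(var[1])
        return jnp.array([x, y])

    jac = jax.jacobian(f)(var)
    jac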
    

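    Note that you also need to build the return value with jnp.array rather than np.array: inside f, x and y are JAX tracers, and np.array cannot convert tracers into concrete arrays, so it raises a TracerArrayConversionError.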
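As a sanity check, here is a small sketch (not part of the original answer) that compares the corrected function against the hand-derived Jacobian of x = r cos θ, y = r sin θ, and shows what happens if the output is built with np.array instead of jnp.array. The function names polar_to_cartesian, analytic_jacobian and with_numpy_output are hypothetical, chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


def polar_to_cartesian(var):
    # var = [r, theta]; x and y are computed from the argument, so JAX traces them
    r, theta = var[0], var[1]
    return jnp.array([r * jnp.cos(theta), r * jnp.sin(theta)])


def analytic_jacobian(var):
    # Hand-derived d(x, y)/d(r, theta) for x = r cos(theta), y = r sin(theta)
    r, theta = var[0], var[1]
    return np.array([[np.cos(theta), -r * np.sin(theta)],
                     [np.sin(theta), r * np.cos(theta)]])


def with_numpy_output(var):
    # Same map, but the output is built with np.array instead of jnp.array
    return np.array([var[0] * jnp.cos(var[1]), var[0] * jnp.sin(var[1])])


var = np.array([4.0, np.pi / 4])

# JAX computes in float32 by default, so compare with a small tolerance
jac = jax.jacobian(polar_to_cartesian)(var)
matches_analytic = np.allclose(jac, analytic_jacobian(var), atol=1e-5)
print("matches analytic Jacobian:", matches_analytic)

# NumPy cannot turn the JAX tracers seen inside with_numpy_output into a
# concrete array, so tracing fails instead of silently returning zeros
error_name = None
try:
    jax.jacobian(with_numpy_output)(var)
except jax.errors.TracerArrayConversionError as err:
    error_name = type(err).__name__
print("np.array output raised:", error_name)
```

The contrast with the question's code is the point: there, np.array received already-computed values rather than tracers, so no error was raised and JAX simply saw a constant output.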