
Understanding how JAX tracers vs. static values work


I'm new to JAX and am trying to write some simple code that, at one point, needs to call a SciPy function. Then I want to take the derivative of the result.

The code doesn't run and raises an error. The code and the error are below. I have read the JAX documentation a couple of times but couldn't figure out how to write the code correctly.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import scipy
key = random.PRNGKey(1)

size = 3
x = random.uniform(key, (size, size), dtype=jnp.float32)

def error_func(x):
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
    error = jnp.sum(jnp.array(dists))
    return error

error_diff = grad(error_func)

print(error_func(x))
print(error_diff(x))

And I get the following error:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
3.2158318
Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 646, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 722, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2179, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/core.py", line 598, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559  0.3129729  0.12388372]
 [0.548188   0.8851279  0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188  , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559  0.3129729  0.12388372]
 [0.548188   0.8851279  0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188  , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Solution

  • JAX transformations only work on JAX functions, not on NumPy or SciPy functions (this is discussed briefly at the link shown in the error message above). If you want to use grad and other JAX transformations, you need to write your logic using JAX operations, not operations from NumPy, SciPy, or other libraries that are not JAX-compatible.
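To make the tracer idea concrete: under grad, your function receives a JAX tracer instead of a concrete array, and any call that tries to turn that tracer into a NumPy array (as np.asarray does inside SciPy's cdist) fails. The sketch below uses two small hypothetical functions, sum_sq_numpy and sum_sq_jax, to show the same error and the jax.numpy fix side by side.

```python
import numpy as np
import jax
import jax.numpy as jnp

def sum_sq_numpy(x):
    # np.asarray calls __array__ on its argument; under jax.grad that argument
    # is a tracer, so this raises TracerArrayConversionError.
    return np.sum(np.asarray(x) ** 2)

def sum_sq_jax(x):
    # The same computation written with jax.numpy can be traced and differentiated.
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)

print(sum_sq_numpy(x))          # fine outside any transformation: 5.0
print(jax.grad(sum_sq_jax)(x))  # gradient is 2 * x: [0. 2. 4.]

try:
    jax.grad(sum_sq_numpy)(x)
except jax.errors.TracerArrayConversionError:
    print("TracerArrayConversionError under jax.grad")
```

The NumPy version only breaks once a transformation is applied, which matches the question: error_func(x) printed a value, and only error_diff(x) failed.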

    JAX does not currently include any wrappers of scipy.spatial.distance (though there are some in progress, see #16147), so the best option would be to write the code yourself. Fortunately, cdist is pretty straightforward:

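    import jax.numpy as jnp
    from jax import grad
    
    def cdist(x, y, metric='euclidean'):
        # Pairwise distances between the rows of x and y, built only from
        # jax.numpy operations so that grad can trace through it.
        assert x.ndim == y.ndim == 2
        if metric != 'euclidean':
            raise NotImplementedError(f"{metric=}")
        # (n, 1, d) - (1, m, d) broadcasts to (n, m, d); summing over the last
        # axis gives squared distances of shape (n, m).
        return jnp.sqrt(jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))
    
    def error_func(x):
        dists = cdist(x, x, metric='euclidean')
        error = jnp.sum(dists)
        return error
    
    error_diff = grad(error_func)
    
    # x is the 3x3 random array defined in the question
    print(error_func(x))
    # 3.2158318
    
    print(error_diff(x))
    # [[nan nan nan]
    #  [nan nan nan]
    #  [nan nan nan]]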
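
SciPy's cdist still works on concrete (non-traced) arrays, so it can serve as a reference outside of grad. A quick sanity check, assuming the euclidean-only formula from the cdist above (copied here so the block runs on its own) and arbitrary random inputs of different sizes:

```python
import numpy as np
import jax.numpy as jnp
from scipy.spatial.distance import cdist as scipy_cdist

def cdist(x, y):
    # Euclidean-only pairwise distances, same formula as the hand-written cdist.
    return jnp.sqrt(jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))

rng = np.random.default_rng(0)
a = rng.uniform(size=(4, 3)).astype(np.float32)
b = rng.uniform(size=(5, 3)).astype(np.float32)

ours = np.asarray(cdist(jnp.asarray(a), jnp.asarray(b)))  # shape (4, 5)
ref = scipy_cdist(a, b, metric='euclidean')                # SciPy computes in float64

print(ours.shape)                          # (4, 5)
print(np.allclose(ours, ref, atol=1e-5))   # True, up to float32 rounding
```

The tolerance is loose because the JAX version runs in float32 while SciPy promotes to float64.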
    

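    You'll notice that the gradient is nan everywhere. This is expected: the diagonal of the distance matrix holds each point's distance to itself, sqrt(0.0), and grad(jnp.sqrt)(0.0) diverges (returns infinity). The chain rule multiplies that infinity by the derivative of the squared difference, which is 0.0 on the diagonal, and 0.0 * inf is nan under IEEE floating-point rules. Since every row of x has its own diagonal term, the nan reaches every entry of the gradient.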
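
If a finite gradient is needed, one common workaround (not part of the original answer) is the "double where" trick: substitute a harmless value for the zero squared distances before taking the square root, then mask the result back to zero. The name safe_cdist is hypothetical. The self-distances still contribute 0.0 to the sum, but their gradient becomes 0 instead of nan.

```python
import jax
import jax.numpy as jnp

def safe_cdist(x, y):
    # Squared Euclidean distances, shape (len(x), len(y)).
    sq = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    is_zero = sq == 0.0
    # Inner where: feed sqrt a dummy 1.0 wherever sq is zero, so sqrt's
    # derivative is finite there and no inf is ever produced.
    safe_sq = jnp.where(is_zero, 1.0, sq)
    # Outer where: put the true distance (0.0) back; the gradient flowing
    # into the dummy entries is exactly zero.
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))

def error_func(x):
    return jnp.sum(safe_cdist(x, x))

x = jax.random.uniform(jax.random.PRNGKey(1), (3, 3))
g = jax.grad(error_func)(x)
print(g)  # finite values, no nan
```

Both where calls are needed: with only the outer one, jnp.sqrt would still be evaluated at 0.0 and its infinite derivative would again be multiplied by zero inside the backward pass.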