I want to create a jittable function that outputs a distrax
distribution object. For instance:
import distrax
import jax
import jax.numpy as jnp
def f(x):
    dist = distrax.Categorical(logits=jnp.sin(x))
    return dist
jit_f = jax.jit(f)
a = jnp.array([1,2,3])
dist = jit_f(a)
Currently this code gives me the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 628, in cache_miss
out = tree_unflatten(out_pytree_def, out_flat)
File "F:\jax_env\lib\site-packages\jax\_src\tree_util.py", line 75, in tree_unflatten
return treedef.unflatten(leaves)
File "F:\jax_env\lib\site-packages\distrax\_src\utils\jittable.py", line 40, in tree_unflatten
obj = cls(*args, **kwargs)
File "F:\jax_env\lib\site-packages\distrax\_src\distributions\categorical.py", line 60, in __init__
self._logits = None if logits is None else math.normalize(logits=logits)
File "F:\jax_env\lib\site-packages\distrax\_src\utils\math.py", line 72, in normalize
return jax.nn.log_softmax(logits, axis=-1)
File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 618, in cache_miss
keep_unused=keep_unused))
File "F:\jax_env\lib\site-packages\jax\core.py", line 2031, in call_bind_with_continuation
top_trace = find_top_trace(args)
File "F:\jax_env\lib\site-packages\jax\core.py", line 1122, in find_top_trace
top_tracer._assert_live()
File "F:\jax_env\lib\site-packages\jax\interpreters\partial_eval.py", line 1486, in _assert_live
raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[3] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at <stdin>:1 traced for jit.
------------------------------
The leaked intermediate value was created on line <stdin>:2 (f).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<stdin>:1 (<module>)
<stdin>:2 (f)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
I thought that adding dist = jax.block_until_ready(dist) inside f could fix the problem, but it doesn't.
This looks like the bug in distrax v0.1.2 reported in https://github.com/deepmind/distrax/issues/162. It was fixed by https://github.com/deepmind/distrax/pull/177, which is part of the distrax v0.1.3 release.
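For context on the mechanism: the traceback shows distrax v0.1.2's tree_unflatten rebuilding the object with cls(*args, **kwargs), which re-runs Categorical.__init__ (and its jax.nn.log_softmax call) after jit has returned, on a value that was still a tracer. The following is a simplified, hypothetical sketch (ToyCategorical is not distrax's actual code) of the usual pattern for a pytree that can be returned from a jitted function: keep the arrays in the flattened children, and rebuild the object in tree_unflatten without calling __init__.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyCategorical:
    """Hypothetical, simplified stand-in for a categorical distribution."""

    def __init__(self, logits):
        # Normalisation runs once, when the object is first constructed.
        self.logits = jax.nn.log_softmax(logits, axis=-1)

    def probs(self):
        return jnp.exp(self.logits)

    def tree_flatten(self):
        # The array is a child (leaf), not aux_data, so jit can swap the
        # tracer for the concrete output array when it unflattens.
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so no JAX op is re-run on the
        # leaves during unflattening.
        obj = object.__new__(cls)
        (obj.logits,) = children
        return obj


def f(x):
    return ToyCategorical(logits=jnp.sin(x))


jit_f = jax.jit(f)
a = jnp.array([1, 2, 3])
dist = jit_f(a)
print(dist.probs())  # three probabilities that sum to 1
```

Skipping __init__ in tree_unflatten is, as I understand the JAX pytree documentation, the recommended approach, because transformations may rebuild pytrees with leaf values the constructor does not expect.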
To fix the issue, you should update to distrax v0.1.3 or later.
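The upgrade itself is typically pip install --upgrade "distrax>=0.1.3". To confirm which release is actually installed in the environment, a small check such as the one below may help. It is a sketch: has_jit_fix is a hypothetical helper, and its simplified parser compares only the leading numeric components of the version string.

```python
from importlib.metadata import PackageNotFoundError, version


def has_jit_fix(installed: str, minimum=(0, 1, 3)) -> bool:
    """Return True if version string `installed` is at least `minimum`.

    Simplified parser: keeps the leading digits of each dot-separated
    piece and stops at the first piece with no leading digit, so suffixes
    such as "rc1" or ".dev0" are ignored.
    """
    parts = []
    for piece in installed.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that e.g. "0.2" compares like (0, 2, 0).
    parts += [0] * (len(minimum) - len(parts))
    return tuple(parts) >= tuple(minimum)


if __name__ == "__main__":
    try:
        installed = version("distrax")
    except PackageNotFoundError:
        print("distrax is not installed")
    else:
        status = "includes" if has_jit_fix(installed) else "predates"
        print(f"distrax {installed} {status} the v0.1.3 fix")
```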