In brief, I wrote the following code:
import jax
import jax.numpy as np
labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])
def body_func(carry,x):
    start_idx,arr=carry
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    carry=(start_idx,arr)
    return carry, carry
slices,=np.where(np.diff(labels)!=0)
print(jax.lax.scan(body_func,(0,logits),np.array(slices)))
but got this error:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].
Here's the full situation: I'm developing a model for phase recognition, and I would like to normalize my logits phase by phase using JAX. For example, suppose I have these phase labels and logits:
labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])
I would like to min-max normalize the first 4 elements of logits together, because their phase labels are all 0; then the next 4 elements together, because their phase labels are all 1; and so on. So the normalized logits should look like:
normalized_logits=[0,0.33,0.66,1.0,0,0.33,0.66,1.0,0,0.33,0.66,1.0]
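To pin down the intended behavior, here is a plain NumPy reference written with an ordinary Python loop. It is a sketch for checking results only (the function name normalize_per_phase_loop is hypothetical), and it is exactly the kind of data-dependent loop that cannot be jit-compiled as written:

```python
import numpy as onp

labels = onp.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits = onp.array([1,2,3,4,5,6,7,8,9,10,11,12], dtype=float)

def normalize_per_phase_loop(logits, labels):
    """Min-max normalize each run of logits that shares a phase label."""
    out = onp.empty_like(logits)
    for phase in onp.unique(labels):
        mask = labels == phase          # elements belonging to this phase
        seg = logits[mask]              # data-dependent length: fine eagerly
        out[mask] = (seg - seg.min()) / (seg.max() - seg.min())
    return out

print(normalize_per_phase_loop(logits, labels))
```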
Here's what I tried:
import jax
import jax.numpy as np
labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])
def min_max_normalization(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))
def body_func(carry,x):
    jax.debug.print("carry is {}",carry)
    jax.debug.print("x is {}",x)
    start_idx,arr=carry
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    print(min_max_normalization(jax.lax.dynamic_slice(arr, [start_idx], [jax.lax.tie_in(x, x-start_idx+1)])))
    print(jax.lax.dynamic_slice(arr, [x+1], [jax.lax.tie_in(x, len(arr)-x-1)]))
    carry=(start_idx,arr)
    return carry, carry
slices,=np.where(np.diff(labels)!=0)
print(jax.lax.scan(body_func,(0,logits),np.array(slices)))
This is a debug version; the real return value should be the concatenation of the three dynamically sliced arrays. But I'm getting the error below:
Traceback (most recent call last):
File "/path/test.py", line 21, in <module>
print(jax.lax.scan(body_func,(0,b),np.array(c)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "path/test.py", line 14, in body_func
print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/slicing.py", line 110, in dynamic_slice
static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/core.py", line 2086, in canonicalize_shape
raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /Users/wuhaoyang/Documents/Research/Project_Surgical_Robot/Code/SSM_Med/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "path/test.py", line 21, in <module>
print(jax.lax.scan(body_func,(0,b),np.array(c)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "path/test.py", line 14, in body_func
print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].
I'm not simply using a Python for loop because I will later wrap this function in another one that is jit-compiled, so I want to do this with the pure JAX API. Any help is appreciated; please tell me if you need more information.
JAX arrays used in transformations like jit, vmap, and scan must always be statically shaped (see Sharp bits: Dynamic Shapes for some discussion of this).
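The same rule applies to the np.where(np.diff(labels) != 0) call in the question: its output length depends on the data, so it works eagerly but not under jit. A minimal sketch, assuming the number of phase boundaries (here 2) is known in advance, passes a static size to jnp.nonzero (the function name boundaries is hypothetical):

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])

@jax.jit
def boundaries(labels):
    # Without `size`, the result length would depend on the data and tracing
    # would fail; with it, the output always has exactly 2 entries, padded
    # with `fill_value` if fewer boundaries are found.
    (idx,) = jnp.nonzero(jnp.diff(labels) != 0, size=2, fill_value=-1)
    return idx

print(boundaries(labels))
```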
dynamic_slice allows you to slice a static length at a dynamic position, while you're trying to use it to slice a dynamic length at a static position, and thus you're seeing this concretization error.
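If you do want to keep an iterative scan, one common workaround (a different technique from dynamic_slice, sketched here under the assumption that each phase's [start, stop) bounds are precomputed) is to keep every array full-length and select the current phase with a boolean mask, so only values, never shapes, depend on traced data. The names normalize_range and bounds are hypothetical:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12], dtype=jnp.float32)
# Assumed precomputed [start, stop) bounds for each phase.
bounds = jnp.array([[0, 4], [4, 8], [8, 12]])

def normalize_range(arr, start, stop):
    # start/stop may be traced; the mask has the static shape of arr.
    idx = jnp.arange(arr.shape[0])
    mask = (idx >= start) & (idx < stop)
    seg_min = jnp.min(jnp.where(mask, arr, jnp.inf))   # ignore elements outside the phase
    seg_max = jnp.max(jnp.where(mask, arr, -jnp.inf))
    normalized = (arr - seg_min) / (seg_max - seg_min)
    # Only overwrite the elements of the current phase.
    return jnp.where(mask, normalized, arr)

def body(arr, b):
    return normalize_range(arr, b[0], b[1]), None

out, _ = jax.lax.scan(body, logits, bounds)
print(out)
```

Each step costs O(len(arr)) regardless of the phase length, which is the price of static shapes; the vectorized approach below avoids the loop entirely.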
To solve your problem, I would avoid scan and instead use JAX's segment_min and segment_max functions to compute the output in a vectorized rather than iterative manner:
import jax
import jax.numpy as jnp
labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12])
l_min = jax.ops.segment_min(logits, labels)[labels]
l_max = jax.ops.segment_max(logits, labels)[labels]
normalized_logits = (logits - l_min) / (l_max - l_min)
print(normalized_logits)
# [0. 0.33333334 0.6666667 1. 0. 0.33333334
# 0.6666667 1. 0. 0.33333334 0.6666667 1. ]
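To see why the [labels] indexing is there: each segment reduction returns one value per segment, and indexing that short array with labels gathers each element's segment statistic back to the full length, so it can be subtracted elementwise. The intermediate values for this example are:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12])

seg_min = jax.ops.segment_min(logits, labels)  # one value per segment: 1, 5, 9
seg_max = jax.ops.segment_max(logits, labels)  # one value per segment: 4, 8, 12
# Gather back to the original length: 1,1,1,1,5,5,5,5,9,9,9,9
per_element_min = seg_min[labels]
print(seg_min, seg_max, per_element_min)
```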
If you want this to be compatible with jit and other transformations, you'll need to pass a static num_segments argument to your segment reductions to specify an upper bound for the number of segments present:
l_min = jax.ops.segment_min(logits, labels, num_segments=3)[labels]
l_max = jax.ops.segment_max(logits, labels, num_segments=3)[labels]
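Putting this together, a jit-compiled version might look like the sketch below. The function name normalize_by_phase and the num_phases parameter are hypothetical; num_phases is marked static so it is a plain Python int at trace time. Because an upper bound is enough, any unused segment slots are never gathered by [labels] and do not affect the result:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_phases")
def normalize_by_phase(logits, labels, num_phases):
    # num_phases fixes the output shape of the segment reductions.
    l_min = jax.ops.segment_min(logits, labels, num_segments=num_phases)[labels]
    l_max = jax.ops.segment_max(logits, labels, num_segments=num_phases)[labels]
    return (logits - l_min) / (l_max - l_min)

labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12])
print(normalize_by_phase(logits, labels, num_phases=3))
```

Changing num_phases triggers a recompilation, since static arguments are part of the compilation cache key.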