To keep this brief, I wrote the following code:
```python
import jax
import jax.numpy as np

def body_func(carry, x):
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    return carry, carry
```
but got:

```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/ for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].
```
Here's the full situation: I'm developing a model for phase recognition, and I would like to normalize my logits phase by phase using JAX. For example, suppose I have the phase labels and logits:
I would like to normalize the first 4 elements of the logits together, because their phase labels all belong to phase 0, then the next 4 elements together, because they all belong to phase 1, and so on. So the normalized logits should look like:
Here's what I tried:
```python
import jax
import jax.numpy as np

def min_max_normalization(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))

def body_func(carry, x):
    jax.debug.print("carry is {}", carry)
    jax.debug.print("x is {}", x)
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    print(min_max_normalization(jax.lax.dynamic_slice(arr, [start_idx], [jax.lax.tie_in(x, x-start_idx+1)])))
    print(jax.lax.dynamic_slice(arr, [x+1], [jax.lax.tie_in(x, len(arr)-x-1)]))
    return carry, carry
```
This is a debug version; the actual return value should concatenate the three dynamically sliced arrays. But I'm getting the error below:
```
Traceback (most recent call last):
  File "/path/", line 21, in <module>
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/", line 250, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/", line 236, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/", line 64, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/", line 58, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/", line 314, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "path/", line 14, in body_func
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/", line 110, in dynamic_slice
    static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/", line 2086, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /Users/wuhaoyang/Documents/Research/Project_Surgical_Robot/Code/SSM_Med/ for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "path/", line 21, in <module>
  File "path/", line 14, in body_func
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/ for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].
```
I'm not simply using a Python `for` loop because I'm later going to wrap this function in another one that is `jit`-compiled, so I want to do this with pure JAX APIs. Any help is appreciated; please tell me if you need more information.
JAX arrays used in transformations like `jit`, `vmap`, and `scan` must always be statically shaped (see Sharp bits: Dynamic Shapes for some discussion of this).
`lax.dynamic_slice` allows you to slice a static length at a dynamic position, while you're trying to use it to slice a dynamic length at a static position; that's why you're seeing this concretization error.
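A minimal sketch of that distinction, assuming a toy array `arr = jnp.arange(10)` and two hypothetical helpers, `window_at` and `prefix_masked` (illustrative names, not from the question). Under `jit`, the start index passed to `lax.dynamic_slice` may be a traced value, but the slice size must be a concrete Python int. When the length itself depends on data, one common workaround is to keep the full static shape and mask out the positions you don't want.

```python
import jax
import jax.numpy as jnp
from jax import lax

arr = jnp.arange(10)  # toy data for illustration

@jax.jit
def window_at(start):
    # Allowed: the slice *size* (3,) is a static Python tuple;
    # only the *start* index is a traced value.
    return lax.dynamic_slice(arr, (start,), (3,))

@jax.jit
def prefix_masked(length):
    # A traced `length` cannot determine the output shape, so keep
    # the full static shape and zero out positions >= length instead.
    idx = jnp.arange(arr.shape[0])
    return jnp.where(idx < length, arr, 0)

print(window_at(4))      # first three elements starting at index 4
print(prefix_masked(3))  # first three elements kept, the rest zeroed
```

Note that `dynamic_slice` clamps the start index so the window stays in bounds, so `window_at(9)` returns the last three elements rather than raising an error.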
To solve your problem, I would avoid `scan` and instead use JAX's `segment_min` and `segment_max` functions to compute the output in a vectorized rather than iterative manner:
```python
import jax
import jax.numpy as jnp

labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12])

l_min = jax.ops.segment_min(logits, labels)[labels]
l_max = jax.ops.segment_max(logits, labels)[labels]
normalized_logits = (logits - l_min) / (l_max - l_min)
# [0.         0.33333334 0.6666667  1.         0.         0.33333334
#  0.6666667  1.         0.         0.33333334 0.6666667  1.        ]
```
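To make it clearer what the segment reductions compute, here is a sketch of the same per-phase min/max written with broadcasting and boolean masks instead. This is a swapped-in technique, not the approach above, and `normalize_by_phase_masked` and `num_phases` are hypothetical names. It builds a `(num_phases, N)` mask, so it uses more memory than the segment reductions, but it also keeps every shape static.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_phases")
def normalize_by_phase_masked(logits, labels, num_phases):
    # (num_phases, N) boolean mask: row p is True where labels == p.
    mask = labels[None, :] == jnp.arange(num_phases)[:, None]
    # Masked reductions: elements outside phase p cannot win its min/max.
    mins = jnp.min(jnp.where(mask, logits, jnp.inf), axis=1)
    maxs = jnp.max(jnp.where(mask, logits, -jnp.inf), axis=1)
    # Gather each element's phase statistics back to its own position.
    lo, hi = mins[labels], maxs[labels]
    return (logits - lo) / (hi - lo)

labels = jnp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
logits = jnp.arange(1, 13, dtype=jnp.float32)
masked_out = normalize_by_phase_masked(logits, labels, num_phases=3)
print(masked_out)
```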
If you want this to be compatible with `jit` and other transformations, you'll need to pass a static `num_segments` argument to your segment reductions to specify an upper bound for the number of segments present:
```python
l_min = jax.ops.segment_min(logits, labels, num_segments=3)[labels]
l_max = jax.ops.segment_max(logits, labels, num_segments=3)[labels]
```
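Putting it together, a minimal sketch of a `jit`-compiled version, assuming a hypothetical wrapper named `normalize_by_phase` and that every label lies in `range(num_segments)`. Marking `num_segments` as static via `static_argnames` keeps it a concrete Python int during tracing, which is what the segment reductions need.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_segments")
def normalize_by_phase(logits, labels, num_segments):
    # Per-phase min and max, gathered back to each element's position.
    l_min = jax.ops.segment_min(logits, labels, num_segments=num_segments)[labels]
    l_max = jax.ops.segment_max(logits, labels, num_segments=num_segments)[labels]
    return (logits - l_min) / (l_max - l_min)

labels = jnp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
logits = jnp.arange(1, 13, dtype=jnp.float32)
out = normalize_by_phase(logits, labels, num_segments=3)
print(out)
```

Because `num_segments` is static, calling the function with a different value triggers a recompilation. Also note that a phase whose logits are all equal gives a zero denominator, and hence NaN, with this formula.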