
Jax dynamic slicing tracer array


In brief, I wrote the following code:

import jax
import jax.numpy as np

labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])

def body_func(carry,x):
    start_idx,arr=carry
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    carry=(start_idx,arr)
    return carry, carry

slices,=np.where(np.diff(labels)!=0)
print(jax.lax.scan(body_func,(0,logits),np.array(slices)))

but got the following error:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].

Here's the full situation: I'm developing a model for phase recognition, and I would like to normalize my logits phase by phase using JAX. For example, suppose I have the following phase labels and logits:

labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])

I would like to min-max normalize the first 4 elements of logits, because their phase labels all belong to phase 0, then the next 4 elements, because their labels all belong to phase 1, and so on. So the normalized logits should look like:

normalized_logits=[0,0.33,0.66,1.0,0,0.33,0.66,1.0,0,0.33,0.66,1.0]
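For reference, here is the intended computation written eagerly in plain NumPy, where variable-length slices are fine because nothing is traced. This is a sketch that assumes each phase occupies a single contiguous run of labels; it does not work under jit, which is exactly the difficulty here.

```python
import numpy as onp  # plain NumPy (named onp because np is jax.numpy above)

labels = onp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
logits = onp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=onp.float32)


def min_max_normalization(x):
    return (x - x.min()) / (x.max() - x.min())


# Indices where the label changes, plus both ends of the array: [0, 4, 8, 12].
bounds = onp.concatenate([[0], onp.flatnonzero(onp.diff(labels)) + 1, [len(labels)]])

# Normalize each contiguous phase independently, then stitch the pieces back together.
normalized = onp.concatenate(
    [min_max_normalization(logits[s:e]) for s, e in zip(bounds[:-1], bounds[1:])]
)
print(normalized.round(2))
```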

Here's what I tried:

import jax
import jax.numpy as np

labels=np.array([0,0,0,0,1,1,1,1,2,2,2,2])
logits=np.array([1,2,3,4,5,6,7,8,9,10,11,12])

def min_max_normalization(x):
    return (x - np.min(x)) / (np.max(x) - np.min(x))

def body_func(carry,x):
    jax.debug.print("carry is {}",carry)
    jax.debug.print("x is {}",x)
    start_idx,arr=carry
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(x, start_idx+1)]))
    print(min_max_normalization(jax.lax.dynamic_slice(arr, [start_idx], [jax.lax.tie_in(x, x-start_idx+1)])))
    print(jax.lax.dynamic_slice(arr, [x+1], [jax.lax.tie_in(x, len(arr)-x-1)]))
    carry=(start_idx,arr)
    return carry, carry

slices,=np.where(np.diff(labels)!=0)
print(jax.lax.scan(body_func,(0,logits),np.array(slices)))

This is a debug version; the actual return value should concatenate the three dynamically sliced arrays. But I'm getting the error below:

Traceback (most recent call last):
  File "/path/test.py", line 21, in <module>
    print(jax.lax.scan(body_func,(0,b),np.array(c)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
                                                                ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
                              ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/test.py", line 14, in body_func
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/slicing.py", line 110, in dynamic_slice
    static_sizes = core.canonicalize_shape(slice_sizes)  # type: ignore
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/core.py", line 2086, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /Users/wuhaoyang/Documents/Research/Project_Surgical_Robot/Code/SSM_Med/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "path/test.py", line 21, in <module>
    print(jax.lax.scan(body_func,(0,b),np.array(c)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/test.py", line 14, in body_func
    print(jax.lax.dynamic_slice(arr, [0], [jax.lax.tie_in(int(1), start_idx+1)]))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function body_func at /path/test.py:10 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[0].

I'm not simply using a for loop because I'm later going to wrap this function in another one that is JIT-compiled, so I want to do this with pure JAX APIs. Any help is appreciated; please tell me if you need more information.


Solution

  • JAX arrays used in transformations like jit, vmap, and scan must always be statically shaped (see Sharp bits: Dynamic Shapes for some discussion of this).

    dynamic_slice lets you extract a slice of static length starting at a dynamic position, whereas you're trying to use it to extract a slice of dynamic length (start_idx+1, which depends on the scan carry) at a static position. That is why you're seeing this concretization error.
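To make the distinction concrete, here is a small sketch; the array and the function names window_at, prefix_masked and prefix_sliced are made up for illustration. Under jit, a traced start index with a static slice size works (dynamic_slice clamps the start so the window stays in bounds), and a variable-length prefix can be emulated by keeping the full static shape and masking with jnp.where. A traced slice size fails during tracing with a TypeError similar to the one in the question.

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(1, 13)  # [1, 2, ..., 12], shape (12,)


@jax.jit
def window_at(start):
    # `start` is traced, but the slice size (4,) is a static Python int: allowed.
    return jax.lax.dynamic_slice(arr, (start,), (4,))


@jax.jit
def prefix_masked(n):
    # A length-n result would need a traced shape, so keep shape (12,)
    # and zero out every position >= n instead.
    return jnp.where(jnp.arange(arr.shape[0]) < n, arr, 0)


@jax.jit
def prefix_sliced(n):
    # Traced slice size: rejected while tracing.
    return jax.lax.dynamic_slice(arr, (0,), (n,))


print(window_at(4))      # [5 6 7 8]
print(window_at(10))     # start is clamped to 8: [ 9 10 11 12]
print(prefix_masked(3))  # [1 2 3 0 0 0 0 0 0 0 0 0]

try:
    prefix_sliced(3)
    raised = False
except TypeError as e:
    raised = True
    print("TypeError:", str(e).splitlines()[0])
```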

    To solve your problem, I would avoid scan and instead use JAX's segment_min and segment_max functions to compute the output in a vectorized rather than iterative manner:

    import jax
    import jax.numpy as jnp
    
    labels = jnp.array([0,0,0,0,1,1,1,1,2,2,2,2])
    logits = jnp.array([1,2,3,4,5,6,7,8,9,10,11,12])
    
    l_min = jax.ops.segment_min(logits, labels)[labels]
    l_max = jax.ops.segment_max(logits, labels)[labels]
    
    normalized_logits = (logits - l_min) / (l_max - l_min)
    print(normalized_logits)
    # [0.         0.33333334 0.6666667  1.         0.         0.33333334
    #  0.6666667  1.         0.         0.33333334 0.6666667  1.        ]
    

    If you want this to be compatible with jit and other transformations, you'll need to pass a static num_segments argument to your segment reductions to specify an upper bound for the number of segments present:

    l_min = jax.ops.segment_min(logits, labels, num_segments=3)[labels]
    l_max = jax.ops.segment_max(logits, labels, num_segments=3)[labels]
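A minimal sketch of the jit-compiled version, assuming the number of phases (or an upper bound on it) is known ahead of time; the wrapper name normalize_by_phase is hypothetical. Marking num_segments as static via static_argnames keeps every intermediate array statically shaped, at the cost of one compilation per distinct num_segments value.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_segments")
def normalize_by_phase(logits, labels, num_segments):
    # num_segments is a static Python int, so each reduction returns a
    # fixed-shape (num_segments,) array even though labels is traced.
    seg_min = jax.ops.segment_min(logits, labels, num_segments=num_segments)
    seg_max = jax.ops.segment_max(logits, labels, num_segments=num_segments)
    # Gather each element's phase min/max back to the full (len(logits),) shape.
    l_min = seg_min[labels]
    l_max = seg_max[labels]
    return (logits - l_min) / (l_max - l_min)


labels = jnp.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
logits = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

exact = normalize_by_phase(logits, labels, num_segments=3)
# Over-estimating the count also works: the unused segments are never gathered.
padded = normalize_by_phase(logits, labels, num_segments=8)
print(exact)
print(padded)
```

One caveat with this formula: if every element of a phase has the same value, l_max - l_min is zero for that phase and its outputs become nan (0/0).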