JAX dynamic slice inside of control flow function


I would like to do dynamic slicing inside lax.while_loop() using a variable carried over between iterations, but I get the error below. I know that for a simple function I can pass the variable as a static value using partial, but how can I handle the case where the variable (in my case length) is part of the loop carry?

new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (0, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

Here is my code; it is only meant to illustrate the problem. What I would like to do is extract part of u and perform some operations on it. Thank you.

import jax.numpy as jnp
import numpy as np
from jax import lax
from functools import partial
import jax

u = jnp.array([[1,2,3,4,5],[0,0,0,0,0]])

def body_fun(carry):
    length, sum_u = carry
    new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
    #new_u = lax.dynamic_slice(u,(0,0),(2,4))
    jax.debug.print("new_u:{}", new_u)
    new_sum_u = jnp.sum(new_u)
    new_length = length -1
    return (new_length, new_sum_u)

def cond_fun(carry):
    length, sum_u = carry
    keep_condition = sum_u < 5
    return keep_condition

init_carry = (5,10)
out = lax.while_loop(cond_fun, body_fun, init_carry)
print(out)
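For comparison, here is a minimal sketch of the distinction the question mentions: when the slice size is a static argument (via partial and jax.jit's static_argnums), the output shape is known at trace time and the call compiles; when the same size is a traced value, dynamic_slice rejects it with the same kind of TypeError shown above. The function names prefix_sum_static and prefix_sum_traced are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

# Hypothetical helper: n is marked static, so the slice shape (2, n) is a
# concrete Python value at trace time. Each distinct n recompiles.
@partial(jax.jit, static_argnums=0)
def prefix_sum_static(n):
    return jnp.sum(lax.dynamic_slice(u, (0, 0), (2, n)))

# Hypothetical helper: n is traced, so (2, n) is not a concrete shape.
@jax.jit
def prefix_sum_traced(n):
    return jnp.sum(lax.dynamic_slice(u, (0, 0), (2, n)))

print(prefix_sum_static(3))  # 1 + 2 + 3 = 6

try:
    prefix_sum_traced(3)
    traced_ok = True
except TypeError as err:
    # Same class of error as in the question: the slice size is not concrete.
    traced_ok = False
    print("TypeError:", str(err).splitlines()[0])
```

The static route does not carry over to the loop: inside lax.while_loop, length is part of the carry and is always traced, so there is no way to mark it static there.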

Solution

  • The problem is that you are attempting to construct a dynamically shaped array, and JAX does not support dynamically shaped arrays (length is a traced value that changes on each iteration of your loop). See JAX Sharp Bits: Dynamic Shapes for more.

    A typical strategy in these cases is to use a statically sized array while masking out a dynamic range of values; in your case, you could use a value of 0 for the masked values so that they don't contribute to the sum. It might look like this:

    def body_fun(carry):
        length, sum_u = carry
        idx = jnp.arange(u.shape[1])
        new_u = jnp.where(idx < length, u, 0)
        jax.debug.print("new_u:{}", new_u)
        new_sum_u = jnp.sum(new_u)
        new_length = length -1
        return (new_length, new_sum_u)
    

    (Side note: it seems you were using dynamic_slice in the hope of generating dynamically shaped arrays, but the "dynamic" in dynamic_slice refers to the dynamic start offset, not a dynamic size.)
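A self-contained sketch of both points in the answer, under stated assumptions. First, the masked body_fun runs inside lax.while_loop. Note that with the question's init_carry of (5, 10), cond_fun is False immediately (10 < 5 fails), so the body is traced but never executed. The starting carry (5, 15) and the condition sum_u >= 5 below are hypothetical, chosen only so the loop iterates a few times. Second, a dynamic_slice whose start offset is traced but whose size is a fixed tuple, which is what dynamic_slice is designed for; window_sum is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

def body_fun(carry):
    length, sum_u = carry
    # Fixed-shape mask: keep columns with index < length, zero out the rest.
    idx = jnp.arange(u.shape[1])
    new_u = jnp.where(idx < length, u, 0)
    return (length - 1, jnp.sum(new_u))

def cond_fun(carry):
    length, sum_u = carry
    # Hypothetical condition: keep shrinking while the masked sum is >= 5.
    return sum_u >= 5

# Hypothetical starting carry chosen so the loop body actually runs.
# Masked sums per iteration: 15, 10, 6, 3 -> stops after the sum drops to 3.
out = lax.while_loop(cond_fun, body_fun, (5, 15))
print(out)  # final carry: length 1, sum 3

@jax.jit
def window_sum(start):
    # start is traced; the slice size (2, 2) is a static tuple.
    window = lax.dynamic_slice(u, (0, start), (2, 2))
    return jnp.sum(window)

print(window_sum(1))  # columns 1 and 2: 2 + 3 = 5
print(window_sum(4))  # start is clamped to 3 so the window fits: 4 + 5 = 9
```

The mask keeps every intermediate array at shape (2, 5), which is what lets the loop compile; the cost is that masked-out elements are still computed, just zeroed. The clamping in the last call is dynamic_slice's documented behavior for out-of-range start indices.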