
# JAX dynamic slice inside of control flow function

I would like to do dynamic slicing inside `lax.while_loop()` using a variable carried through the loop, but I get the error below. I know that for a simple function I can pass the variable as a static value using `partial`, but how can I handle the case where the variable (in my case `length`) is part of the loop carry?

```
new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (0, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>).
If using jit, try using static_argnums or applying jit to smaller subfunctions.
```


Here is my code. It is only meant to illustrate the problem; what I actually want to do is extract part of `u` and do some operations on it. Thank you.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax
from functools import partial
import jax

u = jnp.array([[1,2,3,4,5],[0,0,0,0,0]])

def body_fun(carry):
    length, sum_u = carry
    new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
    #new_u = lax.dynamic_slice(u,(0,0),(2,4))
    jax.debug.print("new_u:{}", new_u)
    new_sum_u = jnp.sum(new_u)
    new_length = length -1
    return (new_length, new_sum_u)

def cond_fun(carry):
    length, sum_u = carry
    keep_condition = sum_u < 5
    return keep_condition

init_carry = (5,10)
out = lax.while_loop(cond_fun, body_fun, init_carry)
print(out)
```


## Solution

The problem is that you are attempting to construct a dynamically-shaped array, and JAX does not support dynamically-shaped arrays (`length` is a dynamic variable in your loop). See JAX Sharp Bits: Dynamic Shapes for more.
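To make the distinction concrete, here is a small sketch (the function names `first_cols` and `first_cols_traced` are hypothetical). When the length is a static argument, as with `static_argnums` or the `partial` approach mentioned in the question, it is a plain Python int at trace time, so the output shape is known and the slice works, at the cost of one compilation per distinct length. When the length is a traced value, which a `while_loop` carry always is, the same call fails with the `TypeError` shown in the question.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

# `length` is static: it is a Python int during tracing, so the output
# shape (2, length) is known. Each new value of `length` triggers a recompile.
@partial(jax.jit, static_argnums=1)
def first_cols(u, length):
    return lax.dynamic_slice(u, (0, 0), (2, length))

print(first_cols(u, 3).tolist())  # [[1, 2, 3], [0, 0, 0]]

# Here `length` is traced (as a while_loop carry always is), so the output
# shape would depend on a runtime value and tracing fails.
@jax.jit
def first_cols_traced(u, length):
    return lax.dynamic_slice(u, (0, 0), (2, length))

try:
    first_cols_traced(u, 3)
    traced_failed = False
except TypeError as err:
    traced_failed = True
    print("TypeError:", str(err).splitlines()[0])
```

Inside `lax.while_loop` there is no way to mark a carry element as static, which is why the masking approach is needed there.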

A typical strategy in these cases is to use a statically-sized array while masking out a dynamic range of values; in your case, you could use a value of 0 for the masked values so that they don't contribute to the sum. It might look like this:

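```python
def body_fun(carry):
    length, sum_u = carry
    idx = jnp.arange(u.shape[1])
    # idx has shape (5,) and broadcasts against u of shape (2, 5):
    # columns with idx >= length are replaced by 0, so the shape stays static.
    new_u = jnp.where(idx < length, u, 0)
    jax.debug.print("new_u:{}", new_u)
    new_sum_u = jnp.sum(new_u)
    new_length = length -1
    return (new_length, new_sum_u)
```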
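For completeness, here is a self-contained version of the masked loop. It is a sketch under assumptions that are not in the question: the initial carry `(5, 15)` and the condition `sum_u > 5` are chosen so that the loop actually iterates (with the question's `(5, 10)` and `sum_u < 5`, the body is traced but never executed), and the mask uses the decremented length so that each iteration sums one fewer column.

```python
import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

def body_fun(carry):
    length, sum_u = carry
    new_length = length - 1
    # Static-shape mask: columns at index >= new_length are replaced by 0,
    # so they do not contribute to the sum. new_u is always shape (2, 5).
    idx = jnp.arange(u.shape[1])
    new_u = jnp.where(idx < new_length, u, 0)
    new_sum_u = jnp.sum(new_u)
    return (new_length, new_sum_u)

def cond_fun(carry):
    length, sum_u = carry
    return sum_u > 5

init_carry = (jnp.int32(5), jnp.sum(u))  # sum of all columns: 15
length, sum_u = lax.while_loop(cond_fun, body_fun, init_carry)
print(int(length), int(sum_u))  # 2 3
```

The sums go 15, 10, 6, 3, and the loop stops once the sum drops to 3. The price of the static shape is that every iteration still processes all five columns; for small arrays this is usually negligible.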


(Side note: it seems you were using `dynamic_slice` in the hope that it could produce dynamically-shaped arrays, but the "dynamic" in `dynamic_slice` refers to a dynamic start offset, not a dynamic size.)
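To illustrate the side note, here is a minimal sketch (the helper name `window` is hypothetical) in which the start offset is a traced value under `jit` while the slice size stays a fixed Python tuple, which is the case `dynamic_slice` is designed for. Note that `lax.dynamic_slice` clamps out-of-range start indices so that the slice stays in bounds instead of raising an error.

```python
import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

@jax.jit
def window(start):
    # `start` is traced, which is fine: only the offset is dynamic.
    # The slice size (2, 2) is static, so the output shape is always (2, 2).
    return lax.dynamic_slice(u, (0, start), (2, 2))

print(window(1).tolist())  # [[2, 3], [0, 0]]
print(window(3).tolist())  # [[4, 5], [0, 0]]
# A start of 4 would run past the last column, so it is clamped to 3.
print(window(4).tolist())  # [[4, 5], [0, 0]]
```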