I would like to do dynamic slicing inside `lax.while_loop()` using a loop-carried variable, but I get the error below. I know that for a simple function I can pass the variable as a static value using `partial`, but how can I handle the case in which the variable (in my case `length`) is carried over?
```
    new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (0, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
```
Here is my code. It only illustrates the problem; what I actually want to do is extract part of `u` and do some operations on it. Thank you.
```python
import jax.numpy as jnp
import numpy as np
from jax import lax
from functools import partial
import jax

u = jnp.array([[1,2,3,4,5],[0,0,0,0,0]])

def body_fun(carry):
    length, sum_u = carry
    new_u = lax.dynamic_slice(u,(0,0),(0,length-1))
    #new_u = lax.dynamic_slice(u,(0,0),(2,4))
    jax.debug.print("new_u:{}", new_u)
    new_sum_u = jnp.sum(new_u)
    new_length = length -1
    return (new_length, new_sum_u)

def cond_fun(carry):
    length, sum_u = carry
    keep_condition = sum_u < 5
    return keep_condition

init_carry = (5,10)
out = lax.while_loop(cond_fun, body_fun, init_carry)
print(out)
```
The problem is that you are attempting to construct a dynamically-shaped array, and JAX does not support dynamically-shaped arrays (`length` is a dynamic variable in your loop). See JAX Sharp Bits: Dynamic Shapes for more.
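For contrast, here is a minimal sketch of the case you described with `partial`: when `length` is an ordinary Python int known before tracing, it can be marked static (here via `static_argnames` on `jax.jit`), so the slice size is concrete at trace time. The function name `sum_first_columns` is hypothetical. Each distinct `length` triggers a fresh compilation, which is exactly why this approach cannot work for a value that changes from one iteration to the next inside a single compiled `while_loop`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

# Hypothetical helper: `length` is static, so it is a plain Python int while tracing.
@partial(jax.jit, static_argnames="length")
def sum_first_columns(u, length):
    # The slice size (rows, length) is concrete, so dynamic_slice accepts it.
    new_u = lax.dynamic_slice(u, (0, 0), (u.shape[0], length))
    return jnp.sum(new_u)

print(sum_first_columns(u, length=4))  # 1 + 2 + 3 + 4 = 10
```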
A typical strategy in these cases is to use a statically-sized array while masking out a dynamic range of values; in your case, you could use a value of `0` for the masked values so that they don't contribute to the sum. It might look like this:
```python
def body_fun(carry):
    length, sum_u = carry
    idx = jnp.arange(u.shape[1])
    new_u = jnp.where(idx < length, u, 0)
    jax.debug.print("new_u:{}", new_u)
    new_sum_u = jnp.sum(new_u)
    new_length = length -1
    return (new_length, new_sum_u)
```
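A self-contained version of the masked loop, as a sketch: the initial carry `(5, 15)` and the stopping rule in `cond_fun` are assumptions chosen so that the body actually runs. (With the original `sum_u < 5` and an initial sum of 10, the condition is already false on entry, so the body is traced, which is where the `TypeError` comes from, but never executed.)

```python
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

def body_fun(carry):
    length, sum_u = carry
    idx = jnp.arange(u.shape[1])  # static shape (5,)
    # Keep columns whose index is < length and zero the rest; shape stays (2, 5).
    new_u = jnp.where(idx < length, u, 0)
    new_sum_u = jnp.sum(new_u)
    return (length - 1, new_sum_u)

def cond_fun(carry):
    length, sum_u = carry
    # Assumed stopping rule for illustration: continue while the sum is at least 5.
    return sum_u >= 5

init_carry = (5, 15)
out = lax.while_loop(cond_fun, body_fun, init_carry)
print(out)  # final carry: length 1, sum 3
```

Every iteration sees a `(2, 5)` array; only which entries are zeroed changes, so the carry keeps a fixed shape and dtype, which is what `while_loop` requires.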
(Side-note: it seems like you were using `dynamic_slice` in hopes that you could generate dynamic array shapes, but the "dynamic" in `dynamic_slice` refers to the dynamic offset, not a dynamic size.)
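To illustrate the distinction, the sketch below uses `lax.dynamic_slice` the way it is meant to be used: the start index `i` is a traced loop variable, while the slice size `(u.shape[0], window)` is a static Python tuple. The function `window_sums` and its `window` parameter are hypothetical, for illustration only.

```python
import jax.numpy as jnp
from jax import lax

u = jnp.array([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]])

def window_sums(u, window=2):
    # Hypothetical helper: sum every run of `window` consecutive columns.
    n_windows = u.shape[1] - window + 1

    def body(i, acc):
        # `i` is traced (dynamic offset); the slice size is static.
        win = lax.dynamic_slice(u, (0, i), (u.shape[0], window))
        return acc.at[i].set(jnp.sum(win))

    return lax.fori_loop(0, n_windows, body, jnp.zeros(n_windows, u.dtype))

print(window_sums(u))  # [3 5 7 9]
```

If a start index would push the slice past the end of the array, `dynamic_slice` clamps the start index rather than raising, so the result always has the requested static size.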