
Creating a JAX array from existing JAX arrays of different lengths throws an error


I am using the following code to set a particular row of a 2D JAX array from several JAX arrays:

import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])
values = jnp.array([value, value_2, value_3])  # the ValueError is raised here
zeros_array = zeros_array.at[0].set(values)

But I receive the following error:

ValueError: All input arrays must have the same shape.

If I change jnp to np (NumPy), the error disappears. Is there any way to resolve this error? I know one workaround would be to set each of the separate arrays into the 2D array individually, using at[0,1].set(), at[0,2:n].set(), and so on.
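A minimal sketch of the slice-by-slice workaround mentioned in the question, assuming the intent is to write the three pieces back to back into row 0 (the names `pieces`, `row_by_slices` and `row_by_concat` are illustrative). Each piece has a static length, so every slice update has a fixed shape. The same result can be obtained with one update after `jnp.concatenate`, which, unlike `jnp.array` or `jnp.stack`, accepts 1-D inputs of different lengths.

```python
import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
pieces = [jnp.array([1, 2, 3, 4]), jnp.array([1]), jnp.array([1, 2])]

# Write each piece into row 0 at a running offset; the offsets and lengths
# are plain Python ints, so each slice has a static shape.
row_by_slices = zeros_array
offset = 0
for piece in pieces:
    n = piece.shape[0]
    row_by_slices = row_by_slices.at[0, offset:offset + n].set(piece)
    offset += n

# Equivalent single update: concatenation joins 1-D arrays end to end and
# does not require them to have equal lengths (stacking does).
flat = jnp.concatenate(pieces)  # shape (7,)
row_by_concat = zeros_array.at[0, :flat.shape[0]].set(flat)

print(row_by_concat[0])
```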


Solution

  • What you have in mind is a "ragged array", and no, there is not currently any way to do this in JAX. In older versions of NumPy this works by returning an array of dtype object (with a deprecation warning), but in newer versions (NumPy 1.24 and later) it raises an error unless you explicitly pass dtype=object, because object arrays are generally inconvenient and inefficient to work with. For example, there is no way to efficiently do the equivalent of the index update in your last line if the updates are stored in an object array.

    Depending on your use case, there are several workarounds you might use in both JAX and NumPy, including storing the rows of your array as a list or using a padded 2D array representation.

    I'll also note that the JAX team is exploring native support for ragged arrays (see e.g. https://github.com/google/jax/pull/16541), but it's still fairly far from being generally useful.
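To illustrate the NumPy behavior described in the answer, here is a small sketch (variable names are illustrative). On NumPy 1.24 and later the implicit call raises `ValueError`; on older releases it returns an object array with a deprecation warning, which is presumably why the question's `np` version appeared to work. Passing `dtype=object` explicitly still builds the array, but the result is a 1-D container of three separate array objects, not a 2-D numeric block.

```python
import numpy as np

value = np.array([1, 2, 3, 4])
value_2 = np.array([1])
value_3 = np.array([1, 2])

# Implicit ragged creation: raises ValueError on NumPy >= 1.24, while older
# versions return a dtype=object array and emit a deprecation warning.
try:
    np.array([value, value_2, value_3])
    implicit_ok = True
except ValueError:
    implicit_ok = False

# Explicit object dtype: a 1-D array whose elements are references to the
# three original arrays.
ragged = np.array([value, value_2, value_3], dtype=object)
print(implicit_ok, ragged.shape, ragged.dtype)
```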
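A sketch of the two workarounds mentioned in the answer, assuming each ragged piece is meant to become one row of the (3, 8) array. `pad_rows` is a hypothetical helper written for this example, not part of JAX; it uses `jnp.pad` and `jnp.stack` to build a rectangular array together with the row lengths and a boolean mask marking the real (non-padding) entries, so computations that must ignore padding can go through `jnp.where`.

```python
import jax.numpy as jnp

# Workaround 1: keep the ragged rows as a plain Python list of JAX arrays.
rows = [jnp.array([1, 2, 3, 4]), jnp.array([1]), jnp.array([1, 2])]


# Workaround 2: a padded 2-D array plus per-row lengths and a validity mask.
# `pad_rows` is a hypothetical helper for illustration, not a JAX API.
def pad_rows(rows, width, fill=0):
    lengths = jnp.array([r.shape[0] for r in rows])
    # Right-pad every row with `fill` up to `width`, then stack into 2-D.
    padded = jnp.stack(
        [jnp.pad(r, (0, width - r.shape[0]), constant_values=fill) for r in rows]
    )
    # mask[i, j] is True where column j holds a real value of row i.
    mask = jnp.arange(width)[None, :] < lengths[:, None]
    return padded, lengths, mask


padded, lengths, mask = pad_rows(rows, width=8)

# The padded block is rectangular, so an ordinary index update works.
zeros_array = jnp.zeros((3, 8))
zeros_array = zeros_array.at[:3].set(padded)

# Example of a padding-aware computation: per-row mean over real entries only.
row_means = jnp.where(mask, padded, 0).sum(axis=1) / lengths
print(row_means)
```

Because the padded array has a fixed shape, it can be used with the usual array operations; the trade-off is wasted memory when row lengths vary widely.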