python numpy jax

Weird shape when indexing a JAX array


I am experiencing a weird issue when indexing a JAX array with a list. If I stop in a debugger in the middle of my code, I see the following:

[Screenshot: indexing inside the code]

These arrays are created by converting a NumPy array.

However, when I try the same indexing in a new Python session, I get the behavior I expect:

[Screenshot: indexing in a new instance]

What is happening here?


Solution

  • This is working as expected. JAX follows the semantics of NumPy indexing, and in the case of advanced indexing with multiple scalars and integer arrays separated by slices, the indexed dimensions are combined via broadcasting and moved to the front of the output array. You can read more about the details of this kind of indexing in the NumPy documentation: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing. In particular:

    Two cases of index combination need to be distinguished:

    • The advanced indices are separated by a slice, Ellipsis or newaxis. For example x[arr1, :, arr2].
    • The advanced indices are all next to each other. For example x[..., arr1, arr2, :] but not x[arr1, :, 1] since 1 is an advanced index in this regard.

    In the first case, the dimensions resulting from the advanced indexing operation come first in the result array, and the subspace dimensions after that. In the second case, the dimensions from the advanced indexing operations are inserted into the result array at the same spot as they were in the initial array.
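To make the two cases in the quoted passage concrete, here is a small NumPy sketch that evaluates the three example expressions from it. The array shape `(5, 6, 7, 8)` and the index arrays `arr1` and `arr2` are arbitrary choices for illustration, not taken from the question; only the resulting shapes matter.

```python
import numpy as np

# Arbitrary 4-D array; only its shape matters for this illustration.
x = np.zeros((5, 6, 7, 8))
arr1 = np.array([0, 1])
arr2 = np.array([2, 3])

# Case 1: advanced indices separated by a slice.
# The broadcast index dimension (size 2) moves to the front,
# followed by the sliced axis (6) and the untouched last axis (8).
separated = x[arr1, :, arr2]
print(separated.shape)  # (2, 6, 8)

# Case 2: advanced indices next to each other.
# The broadcast index dimension replaces axes 1 and 2 in place.
adjacent = x[..., arr1, arr2, :]
print(adjacent.shape)  # (5, 2, 8)

# The scalar 1 also counts as an advanced index, so this is case 1 again:
# arr1 and 1 broadcast to shape (2,), which goes to the front.
scalar_mixed = x[arr1, :, 1]
print(scalar_mixed.shape)  # (2, 6, 8)
```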

    The code in your program falls under the first case, while the code in your separate interpreter falls under the second case. This is why you're seeing different results.

    Here's a concise example of this difference:

    >>> import numpy as np
    >>> x = np.zeros((3, 4, 5))
    
    >>> x[0, :, [1, 2]].shape  # size-2 dimension moved to front
    (2, 4)
    
    >>> x[:, 0, [1, 2]].shape  # size-2 dimension not moved to front
    (3, 2)
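If the goal is the `(4, 2)` layout for the first expression in the example above, one option (a sketch, not the only way) is to apply the scalar index first, so that only one advanced index remains and it stays in place; another is to keep the original expression and move the leading axis afterwards with `np.moveaxis`. The sketch uses NumPy; since JAX follows the same indexing semantics, the same expressions should carry over to `jax.numpy` arrays (with `jnp.moveaxis` in place of `np.moveaxis`), though that is not exercised here.

```python
import numpy as np

# Distinct values so we can check that both options pick the same elements.
x = np.arange(60).reshape(3, 4, 5)
idx = np.array([1, 2])

# Original expression: the scalar 0 and idx are advanced indices separated
# by a slice, so the size-2 dimension is moved to the front.
front = x[0, :, idx]
print(front.shape)  # (2, 4)

# Option 1: index the scalar first. x[0] has shape (4, 5), and the
# remaining index has a single advanced index, which stays in place.
in_place = x[0][:, idx]
print(in_place.shape)  # (4, 2)

# Option 2: keep the original expression and move the leading axis to the end.
moved = np.moveaxis(front, 0, -1)
print(moved.shape)  # (4, 2)

# Both options select the same elements in the same layout.
print(np.array_equal(in_place, moved))  # True
```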