How to index an ndarray of unknown dimensions


I have Python code that uses a NumPy array to store multi-dimensional data. The number of dimensions is determined at runtime, so I cannot know it in advance; it can be anywhere from 3 to 6. One thing I do know is that every dimension that exists in the array has 100 elements. Right now my analysis code looks something like this:

ndim = myarray.ndim
if ndim == 2:
    for p0 in range(100):
        do_something(myarray[:, p0])
elif ndim == 3:
    for p0 in range(100):
        for p1 in range(100):
            do_something(myarray[:, p0, p1])
elif ndim == 4:
    for p0 in range(100):
        for p1 in range(100):
            for p2 in range(100):
                do_something(myarray[:, p0, p1, p2])
elif ndim == 5:
    for p0 in range(100):
        for p1 in range(100):
            for p2 in range(100):
                for p3 in range(100):
                    do_something(myarray[:, p0, p1, p2, p3])
elif ndim == 6:
    for p0 in range(100):
        for p1 in range(100):
            for p2 in range(100):
                for p3 in range(100):
                    for p4 in range(100):
                        do_something(myarray[:, p0, p1, p2, p3, p4])

Of course this works, but I find the code cluttered and not very elegant. Is there a better way to do this? I am fairly sure there must be a way to index the array without knowing the number of dimensions in advance, but I cannot find the right function in the NumPy documentation.


Solution

  • One generic solution is to use np.ndindex, which returns an iterator over all N-dimensional indices for a given shape. In your case you only want to iterate over the non-batch axes (all except the first), which you can do by passing the corresponding part of the shape tuple. Prepending an Ellipsis to each index keeps the first axis whole, so myarray[(...,) + idx] is equivalent to myarray[:, p0, p1, ...]. See the following minimal example:

    import numpy as np
    
    myarray = np.arange(24).reshape(2, 3, 4)
    
    shape = myarray.shape
    
    def do_something(array):
        print(array)
    
    for idx in np.ndindex(shape[1:]):
        do_something(myarray[(...,) + idx])
    

    Which prints:

    [ 0 12]
    [ 1 13]
    [ 2 14]
    [ 3 15]
    [ 4 16]
    [ 5 17]
    [ 6 18]
    [ 7 19]
    [ 8 20]
    [ 9 21]
    [10 22]
    [11 23]
    

    The code works the same way no matter how many dimensions the array has.

    I hope this helps!
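
As a rough sketch of the same idea wrapped in a reusable generator, the loop can be checked against arrays with different numbers of dimensions. The helper name iter_leading_axis_slices is hypothetical, and the small shapes below are stand-ins for the 100-element axes from the question. The check also compares against an alternative that is not part of the answer above: flattening the trailing axes with reshape and walking the columns, which should visit the same slices in the same order for a C-ordered array.

```python
import numpy as np


def iter_leading_axis_slices(arr):
    """Yield arr[:, i, j, ...] for every index combination of the trailing axes.

    Hypothetical helper name; it only wraps the np.ndindex loop shown above.
    """
    for idx in np.ndindex(*arr.shape[1:]):
        # slice(None) is the explicit form of ":" for the first axis.
        yield arr[(slice(None),) + idx]


# Small stand-in shapes; the arrays in the question would use 100 per axis.
for shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    arr = np.arange(np.prod(shape)).reshape(shape)
    slices = list(iter_leading_axis_slices(arr))

    # Alternative sketch: collapse all trailing axes into one, then take columns.
    flat = arr.reshape(arr.shape[0], -1)

    assert len(slices) == flat.shape[1]
    assert all(np.array_equal(s, flat[:, k]) for k, s in enumerate(slices))
```

If do_something does not need to know which index it is processing, the reshape variant may be the simpler choice; np.ndindex is more convenient when the index tuple itself is needed.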