Taking different columns from each 2D slice of a 3D numpy array


Assume the following 3D numpy array:

array([[[4, 1, 3, 5, 0, 1, 5, 4, 3],
        [2, 3, 3, 2, 1, 0, 5, 5, 4],
        [5, 3, 0, 2, 2, 2, 5, 3, 2],
        [0, 3, 1, 0, 2, 4, 1, 1, 5],
        [2, 0, 0, 1, 4, 0, 3, 5, 3]],

       [[2, 2, 4, 1, 3, 4, 1, 1, 5],
        [2, 2, 3, 5, 5, 4, 0, 2, 0],
        [4, 0, 5, 3, 1, 3, 1, 1, 1],
        [4, 5, 0, 0, 5, 3, 3, 2, 4],
        [0, 3, 4, 5, 4, 5, 4, 2, 3]],

       [[1, 3, 2, 2, 0, 4, 5, 0, 2],
        [5, 0, 5, 2, 3, 5, 5, 3, 1],
        [0, 5, 3, 2, 2, 0, 4, 2, 3],
        [4, 4, 0, 3, 2, 1, 5, 3, 0],
        [0, 0, 2, 4, 0, 5, 2, 0, 0]]])

Given a list of column indices [3, 4, 8], is it possible to slice the array without using a for loop?

For example, take column 3 from [0, :, :], column 4 from [1, :, :] and column 8 from [2, :, :] (zero-based indices) to obtain:

array([[5, 2, 2, 0, 1],
       [3, 5, 1, 5, 4],
       [2, 1, 3, 0, 0]])

Solution

  • Here's one way, with np.take_along_axis:

    In [73]: idx = np.array([3,4,8])
    
    # a is the input array, shape (3, 5, 9)
    In [72]: np.take_along_axis(a,idx[:,None,None],axis=2)[:,:,0]
    Out[72]: 
    array([[5, 2, 2, 0, 1],
           [3, 5, 1, 5, 4],
           [2, 1, 3, 0, 0]])
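
To make the shapes in the np.take_along_axis call easier to follow, here is a minimal sketch. It assumes a stand-in array built with np.arange (not the array from the question) that has the same (3, 5, 9) shape; the names idx3d, picked and result are illustrative only.

```python
import numpy as np

# Hypothetical stand-in for the question's array: same shape (3, 5, 9).
a = np.arange(3 * 5 * 9).reshape(3, 5, 9)
idx = np.array([3, 4, 8])

# Reshape idx to (3, 1, 1) so it has as many dimensions as `a`.
idx3d = idx[:, None, None]

# take_along_axis broadcasts idx3d against `a` on every axis except axis=2,
# so each of the 5 rows in slice i is read at column idx[i].
picked = np.take_along_axis(a, idx3d, axis=2)
print(picked.shape)  # (3, 5, 1)

# Drop the length-1 trailing axis to get one output row per 2D slice.
result = picked[:, :, 0]
print(result.shape)  # (3, 5)
```

The index array must have the same number of dimensions as the input, which is why idx is expanded with None before the call and the leftover length-1 axis is indexed away afterwards.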
    

    Another, with explicit integer indexing:

    In [79]: a[np.arange(len(idx)),:,idx]
    Out[79]: 
    array([[5, 2, 2, 0, 1],
           [3, 5, 1, 5, 4],
           [2, 1, 3, 0, 0]])
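
As a self-contained check, the integer-indexing approach can be run against the array from the question and compared with the plain loop it replaces. This is a sketch; the names expected and result are illustrative and not part of the original answer.

```python
import numpy as np

# The 3D array from the question, shape (3, 5, 9).
a = np.array([[[4, 1, 3, 5, 0, 1, 5, 4, 3],
               [2, 3, 3, 2, 1, 0, 5, 5, 4],
               [5, 3, 0, 2, 2, 2, 5, 3, 2],
               [0, 3, 1, 0, 2, 4, 1, 1, 5],
               [2, 0, 0, 1, 4, 0, 3, 5, 3]],

              [[2, 2, 4, 1, 3, 4, 1, 1, 5],
               [2, 2, 3, 5, 5, 4, 0, 2, 0],
               [4, 0, 5, 3, 1, 3, 1, 1, 1],
               [4, 5, 0, 0, 5, 3, 3, 2, 4],
               [0, 3, 4, 5, 4, 5, 4, 2, 3]],

              [[1, 3, 2, 2, 0, 4, 5, 0, 2],
               [5, 0, 5, 2, 3, 5, 5, 3, 1],
               [0, 5, 3, 2, 2, 0, 4, 2, 3],
               [4, 4, 0, 3, 2, 1, 5, 3, 0],
               [0, 0, 2, 4, 0, 5, 2, 0, 0]]])
idx = np.array([3, 4, 8])

# Reference result: the explicit loop the question wants to avoid.
expected = np.stack([a[i, :, j] for i, j in enumerate(idx)])

# The two integer arrays are paired element-wise: (0, 3), (1, 4), (2, 8).
# The slice on axis 1 keeps all 5 rows for each pair.
result = a[np.arange(len(idx)), :, idx]

print(result)
# [[5 2 2 0 1]
#  [3 5 1 5 4]
#  [2 1 3 0 0]]
print(np.array_equal(result, expected))  # True
```

Because the two integer index arrays are separated by a slice, NumPy places the paired dimension first, so the result has shape (3, 5): one row per 2D slice.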