
Selecting from a NumPy array along the last axis


I have a 3D array and a 2D array of indices into its last axis. How can I use the indices to pick one element along the last axis for each position (i, j)?

import numpy as np

# example array
shape = (4,3,2)
x = np.random.uniform(0,1, shape)

# indices
idx = np.random.randint(0,shape[-1], shape[:-1])

The following loop gives the desired result, but there should be a more efficient, vectorized way to do this.

result = np.zeros(shape[:-1])
for i in range(shape[0]):
    for j in range(shape[1]):
        result[i,j] = x[i,j,idx[i,j]]
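To make the loop's semantics concrete, here is a small deterministic sketch. The `np.arange` values and the fixed `idx` are illustrative choices, not taken from the question. It also shows why plain advanced indexing of the last axis is not enough: `x[..., idx]` combines every entry of `idx` with every (i, j) position instead of pairing `idx[i, j]` with (i, j).

```python
import numpy as np

# Illustrative data: x[i, j, k] == 6*i + 2*j + k
x = np.arange(24).reshape(4, 3, 2)
idx = np.array([[0, 1, 0],
                [1, 1, 0],
                [0, 0, 1],
                [1, 0, 1]])

# Reference result from the loop in the question
result = np.zeros(x.shape[:-1], dtype=x.dtype)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        result[i, j] = x[i, j, idx[i, j]]

print(result)
# [[ 0  3  4]
#  [ 7  9 10]
#  [12 14 17]
#  [19 20 23]]

# Indexing only the last axis with the 2D idx does NOT pair idx[i, j]
# with (i, j): every idx entry is applied at every (i, j) position.
wrong = x[..., idx]
print(wrong.shape)  # (4, 3, 4, 3)
```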

Solution

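  • A possible solution uses np.take_along_axis. Its index argument must have the same number of dimensions as x, so idx is first expanded to shape (4, 3, 1) with np.expand_dims, and the trailing length-1 axis of the result is then removed with squeeze: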

    np.take_along_axis(x, np.expand_dims(idx, axis=-1), axis=-1).squeeze(axis=-1)
    

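    Alternatively, build integer index grids for the first two axes with np.meshgrid and combine them with idx in a single advanced-indexing expression: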

    i, j = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    x[i, j, idx]
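A minimal sketch for checking that both answers agree with the loop, and that the take_along_axis form also works when there are more than two leading axes. The helper name `select_last_axis` is hypothetical, and the `np.ogrid` variant is an extra broadcasting-based alternative that is not part of the answers above: it builds open (broadcastable) index grids of shapes (4, 1) and (1, 3) instead of the full (4, 3) grids that `np.meshgrid` returns.

```python
import numpy as np

def select_last_axis(x, idx):
    """Hypothetical helper: out[p] == x[p + (idx[p],)] for every leading index p.

    idx must have shape x.shape[:-1]; any number of leading axes is allowed.
    """
    return np.take_along_axis(x, idx[..., np.newaxis], axis=-1)[..., 0]

rng = np.random.default_rng(0)
shape = (4, 3, 2)
x = rng.uniform(0, 1, shape)
idx = rng.integers(0, shape[-1], shape[:-1])

# Reference: the loop from the question
expected = np.zeros(shape[:-1])
for i in range(shape[0]):
    for j in range(shape[1]):
        expected[i, j] = x[i, j, idx[i, j]]

# Answer 1: take_along_axis
a = np.take_along_axis(x, np.expand_dims(idx, axis=-1), axis=-1).squeeze(axis=-1)

# Answer 2: full index grids from meshgrid
i, j = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
b = x[i, j, idx]

# Extra variant: open grids broadcast against idx to shape (4, 3)
ii, jj = np.ogrid[:shape[0], :shape[1]]
c = x[ii, jj, idx]

for candidate in (a, b, c, select_last_axis(x, idx)):
    assert np.array_equal(candidate, expected)

# The helper also handles more leading axes, e.g. a 4D array
x4 = rng.uniform(0, 1, (2, 4, 3, 5))
idx4 = rng.integers(0, 5, (2, 4, 3))
out4 = select_last_axis(x4, idx4)
print(out4.shape)  # (2, 4, 3)
```

The open-grid variant avoids materializing full (4, 3) index arrays for the leading axes, while take_along_axis needs no grids at all and does not depend on how many leading axes there are.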