Tags: python, numpy, numpy-ndarray, numpy-slicing

Remove specific indices in each row of a numpy ndarray


I have integer arrays of the type:

import numpy as np

seed_idx = np.asarray([[0, 1],
                       [1, 2],
                       [2, 3],
                       [3, 4]], dtype=np.int_)

target_idx = np.asarray([[2,9,4,1,8],
                         [9,7,6,2,4],
                         [1,0,0,4,9],
                         [7,1,2,3,8]], dtype=np.int_)

For each row of target_idx, I want to select the elements whose column indices are not listed in the corresponding row of seed_idx. The resulting array should thus be:

[[4,1,8],
 [9,2,4],
 [1,0,9],
 [7,1,2]]

In other words, I want to do something similar to np.take_along_axis(target_idx, seed_idx, axis=1), but excluding the indices instead of keeping them.

What is the most elegant way to do this? I'm finding it surprisingly hard to come up with something neat.


Solution

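  • You can overwrite the values you don't want with a sentinel using np.put_along_axis, then select the others. Note that np.put_along_axis modifies target_idx in place, so work on a copy if you still need the original: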

    >>> np.put_along_axis(target_idx, seed_idx, -1, axis=1)
    >>> target_idx[np.where(target_idx != -1)].reshape(len(target_idx), -1)
    array([[4, 1, 8],
           [9, 2, 4],
           [1, 0, 9],
           [7, 1, 2]])
    

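    If -1 can occur in target_idx, use a sentinel that cannot, such as target_idx.min() - 1. Compute it before calling np.put_along_axis and use it in both lines. The reshape also relies on every row of seed_idx removing the same number of distinct indices.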
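
If you would rather not modify target_idx or pick a sentinel at all, the same idea can be sketched with a boolean mask: start from an all-True mask, switch off the positions named in seed_idx, and index with it. The helper name drop_indices_per_row is made up for illustration, and the sketch assumes every row of seed_idx holds the same number of distinct, in-range indices (otherwise the final reshape fails).

```python
import numpy as np

def drop_indices_per_row(arr, idx):
    """Return a new array with the columns listed in `idx` removed, row by row.

    Hypothetical helper; assumes each row of `idx` contains the same
    number of distinct, in-range column indices.
    """
    # Boolean mask of the same shape as arr: True means "keep this element".
    keep = np.ones(arr.shape, dtype=bool)
    # Switch off the positions named in idx (modifies only the mask).
    np.put_along_axis(keep, idx, False, axis=1)
    # Boolean indexing flattens in row-major order, so reshape back to rows.
    return arr[keep].reshape(arr.shape[0], -1)

seed_idx = np.asarray([[0, 1],
                       [1, 2],
                       [2, 3],
                       [3, 4]], dtype=np.int_)

target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]], dtype=np.int_)

result = drop_indices_per_row(target_idx, seed_idx)
print(result)
```

Because only the mask is written to, target_idx keeps its original contents, and any integer (including -1) can appear in the data.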