Fast way to remove multiple rows by indices from a PyTorch or NumPy 2D array


I have a NumPy array (and equivalently a PyTorch tensor) of shape Nx3. I also have a list of row indices, called remove_ixs, whose rows I want to remove from this array. N is very large, about 5 million rows, and remove_ixs has 50k entries. This is how I'm doing it now:

mask = [i not in remove_ixs for i in range(my_array.shape[0])]
new_array = my_array[mask,:]

But the first line never finishes; it just takes forever. The code above is NumPy, but equivalent PyTorch code would also work for me.

Is there a faster way to do this with either NumPy or PyTorch?


Solution

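  • Create a boolean mask that is True for the rows you want to remove, then invert it to get a mask of the rows you want to keep. The list comprehension in the question is slow because each "i not in remove_ixs" scans the 50k-element Python list, so it does up to 5 million × 50k comparisons in a Python loop. Here, both setting the mask and indexing the array are vectorized NumPy operations.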

    import numpy as np

    # True marks a row to drop
    remove_mask = np.zeros(my_array.shape[0], dtype=bool)
    remove_mask[remove_ixs] = True
    # Invert to get the rows to keep
    mask = ~remove_mask

    new_array = my_array[mask, :]
    

    Alternatively, start with all True and set the rows to remove to False:

    # Every row is kept except the listed ones
    mask = np.ones(my_array.shape[0], dtype=bool)
    mask[remove_ixs] = False

    new_array = my_array[mask, :]
    

    The first version seems to be faster for smaller arrays, though it is not clear why.
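The answer covers NumPy only, but the question also asks for PyTorch. Below is a minimal sketch of the same keep-mask approach for a torch.Tensor. The helper name remove_rows is hypothetical, not a PyTorch API. The sketch assumes remove_ixs holds integer row indices (a Python list, NumPy integer array, or tensor), and it builds the mask on the same device as the data so it also works for GPU tensors.

```python
import torch


def remove_rows(t: torch.Tensor, remove_ixs) -> torch.Tensor:
    """Hypothetical helper: return a copy of t without the rows in remove_ixs."""
    # Start with every row marked as kept, on the same device as the data.
    keep = torch.ones(t.shape[0], dtype=torch.bool, device=t.device)
    # Convert the indices to a long tensor; duplicate indices are harmless here.
    ix = torch.as_tensor(remove_ixs, dtype=torch.long, device=t.device)
    # Mark the rows to drop in one vectorized assignment.
    keep[ix] = False
    # Boolean indexing along dim 0 selects the remaining rows (returns a new tensor).
    return t[keep]
```

For example, with my_tensor = torch.arange(15).reshape(5, 3), remove_rows(my_tensor, [1, 3]) keeps rows 0, 2 and 4. As in the NumPy version, boolean-mask indexing produces a copy rather than a view of the original tensor.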