Slice sets of columns in numpy


Consider the following NumPy array:

>>> a = np.array([[1, 2, 3, 0, 1], [2, 3, 2, 2, 2], [0, 3, 3, 2, 2]])
>>> a
array([[1, 2, 3, 0, 1],
       [2, 3, 2, 2, 2],
       [0, 3, 3, 2, 2]])

And a list of column-index pairs to select (a given column can appear in more than one pair):

b = [[0,1], [0,3], [1,4]]

How can I slice/broadcast/stride a using b to get the following result, where block j contains the columns listed in b[j]:

array([[[1, 2],
        [2, 3],
        [0, 3]],

       [[1, 0],
        [2, 2],
        [0, 2]],

       [[2, 1],
        [3, 2],
        [3, 2]]])
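To pin down what the expected output means, here is a plain loop-based sketch (not vectorized, and not part of the original post): block j is simply a[:, b[j]], and np.stack stacks those blocks along a new leading axis. The variable name expected is illustrative only.

```python
import numpy as np

a = np.array([[1, 2, 3, 0, 1], [2, 3, 2, 2, 2], [0, 3, 3, 2, 2]])
b = [[0, 1], [0, 3], [1, 4]]

# Block j holds columns b[j] of every row: a[:, b[j]] has shape (3, 2).
# np.stack places the blocks along a new leading axis -> shape (3, 3, 2).
expected = np.stack([a[:, pair] for pair in b])
```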

Solution

  • Use b as column indices to subset the array, then swap the first two axes so that the pair index comes first:

    a[:, b].swapaxes(0, 1)
    
    # array([[[1, 2],
    #         [2, 3],
    #         [0, 3]],
    #
    #        [[1, 0],
    #         [2, 2],
    #         [0, 2]],
    #
    #        [[2, 1],
    #         [3, 2],
    #         [3, 2]]])
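Why this works, as a small sketch assuming a and b are defined as in the question (b converted to an array here so it can be indexed as b[j, k]): indexing axis 1 with a (3, 2) index array replaces that axis with the index array's shape, so a[:, b] has shape (3, 3, 2) and a[:, b][i, j, k] == a[i, b[j, k]]. swapaxes(0, 1) then moves the pair index j to the front. The np.take / np.moveaxis and a.T[b] versions below are equivalent alternatives added for illustration, not part of the original answer.

```python
import numpy as np

a = np.array([[1, 2, 3, 0, 1], [2, 3, 2, 2, 2], [0, 3, 3, 2, 2]])
b = np.array([[0, 1], [0, 3], [1, 4]])

# Integer-array indexing on axis 1: (3, 5) -> (3, 3, 2),
# with picked[i, j, k] == a[i, b[j, k]].
picked = a[:, b]
print(picked.shape)  # (3, 3, 2)

# Swap axes 0 and 1 so the result is indexed as result[j, i, k].
result = picked.swapaxes(0, 1)

# np.take along axis 1 performs the same selection; np.moveaxis(x, 1, 0)
# reorders the axes the same way swapaxes(0, 1) does for a 3-D array.
alt = np.moveaxis(np.take(a, b, axis=1), 1, 0)

# Indexing the transpose selects along the first axis instead:
# a.T[b] has shape (3, 2, 3) indexed as [j, k, i]; reorder to [j, i, k].
alt2 = a.T[b].transpose(0, 2, 1)

print(np.array_equal(result, alt), np.array_equal(result, alt2))  # True True

# Element-wise check of the indexing rule described above.
assert all(
    result[j, i, k] == a[i, b[j, k]]
    for j in range(b.shape[0])
    for i in range(a.shape[0])
    for k in range(b.shape[1])
)
```

Note that a[:, b] always produces a new array (advanced indexing copies), and swapaxes returns a view of that copy, so the final result is not C-contiguous; wrap it in np.ascontiguousarray if later code needs a contiguous layout.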