
numpy take can't index using a slice


According to the numpy docs for take, it does the same thing as “fancy” indexing (indexing arrays using arrays), but it can be easier to use if you need elements along a given axis.
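As a quick illustration of that equivalence (an added sketch, not part of the original question), np.take with an integer list and an explicit axis selects the same elements as fancy indexing on that axis, and the axis can be held in an ordinary variable:

```python
import numpy as np

A = np.arange(20).reshape(4, 5)
axis = 1  # in this sketch, the axis is just a runtime variable

# Fancy indexing with an integer list on axis 1...
fancy = A[:, [1, 2, 3]]
# ...selects the same elements as np.take with that list and an explicit axis.
taken = np.take(A, [1, 2, 3], axis=axis)

print(np.array_equal(fancy, taken))  # True
print(taken.shape)                   # (4, 3)
```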

However, unlike "fancy" or regular numpy indexing, np.take does not appear to support slices as indices:

In [319]: A = np.arange(20).reshape(4, 5)

In [320]: A[..., 1:4]
Out[320]: 
array([[ 1,  2,  3],
       [ 6,  7,  8],
       [11, 12, 13],
       [16, 17, 18]])

In [321]: np.take(A, slice(1, 4), axis=-1)
TypeError: long() argument must be a string or a number, not 'slice'

What is the best way to index an array using slices along an axis only known at runtime?


Solution

  • According to the numpy docs for take it does the same thing as “fancy” indexing (indexing arrays using arrays).

    The second argument to np.take has to be array-like (an array, a list, a tuple, etc.), not a slice object. You could construct an indexing array or list that performs the desired slicing:

    import numpy as np
    
    a = np.arange(24).reshape(2, 3, 4)
    
    np.take(a, slice(1, 4, 2), 2)
    # TypeError: long() argument must be a string or a number, not 'slice'
    
    np.take(a, range(1, 4, 2), 2)
    # array([[[ 1,  3],
    #         [ 5,  7],
    #         [ 9, 11]],
    #
    #        [[13, 15],
    #         [17, 19],
    #         [21, 23]]])
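If the slice is only available as a slice object (for example, passed in at runtime), one way to reuse it with np.take is Python's built-in slice.indices(length). It resolves the slice against a given axis length and returns a (start, stop, step) tuple that range() accepts. A sketch, assuming the axis and slice are held in the variables axis and s:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
axis = 2            # axis to index along, known only at runtime
s = slice(1, 4, 2)  # the slice we would like to apply

# slice.indices(n) fills in None/negative bounds for a sequence of length n
# and returns (start, stop, step), which can be unpacked into range().
idx = range(*s.indices(a.shape[axis]))

taken = np.take(a, idx, axis=axis)
print(np.array_equal(taken, a[:, :, 1:4:2]))  # True
```

Note that np.take always returns a new array, whereas basic slicing such as a[:, :, 1:4:2] returns a view of a.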
    

    What is the best way to index an array using slices along an axis only known at runtime?

    What I often prefer to do is use np.rollaxis to move the axis being indexed into the first position, do the indexing, then roll it back to its original position.

    For example, let's say I want odd-numbered slices of a 3D array along its 3rd axis:

    sliced1 = a[:, :, 1::2]
    

    If I then wanted to specify the axis to slice along at runtime, I could do it like this:

    n = 2    # axis to slice along
    
    sliced2 = np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)
    
    assert np.all(sliced1 == sliced2)
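The same one-liner can be wrapped in a small helper so that both the slice and the axis are runtime arguments. The name slice_along_axis is hypothetical (it is not a NumPy function). This sketch normalizes a negative axis first, because np.rollaxis(x, 0, axis + 1) would otherwise put the axis back in the wrong place for axis=-1:

```python
import numpy as np

def slice_along_axis(arr, s, axis):
    """Hypothetical helper: apply slice object `s` to `arr` along `axis`."""
    ndim = arr.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for a {ndim}-d array")
    axis %= ndim  # turn e.g. -1 into ndim - 1

    # Roll `axis` to the front, slice it, then roll it back to position `axis`.
    front = np.rollaxis(arr, axis, 0)[s]
    return np.rollaxis(front, 0, axis + 1)

a = np.arange(24).reshape(2, 3, 4)

# Compare against ordinary slicing for every axis (including a negative one).
assert np.array_equal(slice_along_axis(a, slice(1, None, 2), 0), a[1::2])
assert np.array_equal(slice_along_axis(a, slice(1, None, 2), 1), a[:, 1::2])
assert np.array_equal(slice_along_axis(a, slice(1, None, 2), 2), a[:, :, 1::2])
assert np.array_equal(slice_along_axis(a, slice(1, None, 2), -1), a[..., 1::2])
```

Since np.rollaxis and basic slicing both return views, the result shares memory with the input array rather than copying it.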
    

    To unpack that one-liner a bit:

    # roll the nth axis to the 0th position
    np.rollaxis(a, n, 0)
    
    # index odd-numbered slices along the 0th axis
    np.rollaxis(a, n, 0)[1::2]
    
    # roll the 0th axis back so that it lies before position n + 1 (note the '+ 1'!)
    np.rollaxis(np.rollaxis(a, n, 0)[1::2], 0, n + 1)
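For reference, the NumPy documentation notes that np.rollaxis is kept for backward compatibility and recommends np.moveaxis (added in NumPy 1.11), whose source/destination arguments avoid the '+ 1' adjustment. Another widely used alternative, not part of the answer above, builds a tuple of slice objects with slice(None) (i.e. ':') on every axis except the one being indexed. A sketch of both, checked against the rollaxis one-liner:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
n = 2                  # axis to slice along
s = slice(1, None, 2)  # equivalent to 1::2

# Reference result: the rollaxis one-liner from above.
rolled = np.rollaxis(np.rollaxis(a, n, 0)[s], 0, n + 1)

# np.moveaxis: move axis n to the front, slice, then move axis 0 back to n.
moved = np.moveaxis(np.moveaxis(a, n, 0)[s], 0, n)

# Tuple of slices: ':' everywhere except axis n, which gets `s`.
index = [slice(None)] * a.ndim
index[n] = s
tupled = a[tuple(index)]

print(np.array_equal(rolled, moved), np.array_equal(rolled, tupled))  # True True
```

Converting the list to a tuple matters: recent NumPy versions no longer interpret a plain list of slices as a multidimensional index.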