
Efficiently getting i-th column of i-th 2D slice of 3D NumPy array, for all i


Suppose I have a NumPy array A of shape (N,N,N). From this, I form a 2D array B of shape (N,N) as follows:

B = np.column_stack( tuple(A[i,:,i] for i in range(N)) )

In other words, I take the i-th column of the i-th 2D slice of A and stack these columns to form B, so that B[j, i] = A[i, j, i].

My question is:

Is there a more efficient way, using NumPy indexing or slicing, to construct B from A? In particular, is it possible to eliminate the Python for loop over the 2D slices of A?


Solution

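  • You can use advanced indexing. Indexing the first and last axes with the same integer array pairs them element-wise, which gives an array C with C[i, j] = A[i, j, i]; transposing it yields B: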

    idx = np.arange(N)  # or idx = range(N)
    A[idx,:,idx].T
    

    Example:

    import numpy as np
    A = np.arange(27).reshape(3,3,3)
    
    idx = np.arange(3)
    A[idx,:,idx].T
    #array([[ 0, 10, 20],
    #       [ 3, 13, 23],
    #       [ 6, 16, 26]])
    
    np.column_stack( tuple(A[i,:,i] for i in range(3)) )
    #array([[ 0, 10, 20],
    #       [ 3, 13, 23],
    #       [ 6, 16, 26]])
    

    Timing: for N = 100, advanced indexing is roughly four times faster than the loop (51.1 µs versus 210 µs in the run below):

    import numpy as np
    
    N = 100
    A = np.arange(N*N*N).reshape(N,N,N)
    
    def adv_index(N):
        # uses the global array A defined above
        idx = range(N)
        return A[idx,:,idx].T
    
    %timeit np.column_stack(tuple(A[i,:,i] for i in range(N)))
    # The slowest run took 4.01 times longer than the fastest. This could mean that an intermediate result is being cached.
    # 1000 loops, best of 3: 210 µs per loop
    
    %timeit adv_index(N)
    # The slowest run took 5.87 times longer than the fastest. This could mean that an intermediate result is being cached.
    # 10000 loops, best of 3: 51.1 µs per loop
    
    (np.column_stack(tuple(A[i,:,i] for i in range(N))) == adv_index(N)).all()
    # True
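
As a cross-check on the indexing above, here is a minimal sketch comparing the loop, advanced indexing, and two other NumPy calls that seem to express the same selection, `np.einsum` and `np.diagonal`. These two alternatives are not part of the original answer, and the array size and variable names are illustrative assumptions; no timing claims are made for them.

```python
import numpy as np

N = 4  # small illustrative size (assumption, not from the original post)
A = np.arange(N * N * N).reshape(N, N, N)

# Reference result from the question: B[j, i] = A[i, j, i]
B_loop = np.column_stack(tuple(A[i, :, i] for i in range(N)))

# Advanced indexing, as in the answer above
idx = np.arange(N)
B_adv = A[idx, :, idx].T

# Alternative 1: the repeated subscript 'i' selects the diagonal over axes 0 and 2
B_einsum = np.einsum('iji->ji', A)

# Alternative 2: np.diagonal appends the diagonal as the last axis,
# so no transpose is needed
B_diag = np.diagonal(A, axis1=0, axis2=2)

print(np.array_equal(B_adv, B_loop))
print(np.array_equal(B_einsum, B_loop))
print(np.array_equal(B_diag, B_loop))
```

Note that `np.diagonal` returns a read-only view in current NumPy versions, so call `.copy()` on the result if B needs to be modified.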