
Diagonals of a multidimensional numpy array


Is there a more Pythonic way to extract the diagonal of every matrix in a 3D array of shape (x, y, y) than the following?

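import numpy as np

def diagonal(A):
    # A is a stack of square matrices with shape (x, y, y)
    x, y, z = A.shape
    assert y == z, "each matrix in the stack must be square"
    diags = []
    for a in A:
        diags.append(np.diagonal(a))  # diagonal of one y x y matrix
    result = np.vstack(diags)
    assert result.shape == (x, y)
    return result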

Solution

    Approach #1

    A clean way is to call np.diagonal on a transposed version of the input array, like so -

    np.diagonal(A.T)
    

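    A.T reverses the order of all the axes, so the two matrix axes, which were the last two, become the first two. Since np.diagonal works on the first two axes by default (axis1=0, axis2=1), it now extracts the diagonal of each matrix in the stack. For an input of shape (x, y, y) the output has shape (x, y). For arrays with more than three dimensions, A.T also reverses the leading (batch) axes, so they come out in reverse order. To keep them in place for any number of dimensions, name the matrix axes explicitly with np.diagonal(A, axis1=-2, axis2=-1).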
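A minimal sketch of the axis handling described above. It uses small arange-filled arrays as hypothetical test data (not the values from the sample run) and contrasts np.diagonal(A.T) with the explicit-axis call on a 3D input and a 4D input:

```python
import numpy as np

# 3D case: a stack of 4 square 3x3 matrices, shape (x, y, y) = (4, 3, 3)
A = np.arange(4 * 3 * 3).reshape(4, 3, 3)

# A.T has shape (3, 3, 4); np.diagonal uses its first two axes by default,
# which are the original last two (matrix) axes of A.
d_T = np.diagonal(A.T)
print(d_T.shape)  # (4, 3)

# Naming the matrix axes explicitly gives the same result for 3D input.
d_ax = np.diagonal(A, axis1=-2, axis2=-1)
print(np.array_equal(d_T, d_ax))  # True

# 4D case: batch shape (2, 5), each entry a 3x3 matrix.
B = np.arange(2 * 5 * 3 * 3).reshape(2, 5, 3, 3)
print(np.diagonal(B.T).shape)                    # (5, 2, 3): batch axes reversed
print(np.diagonal(B, axis1=-2, axis2=-1).shape)  # (2, 5, 3): batch axes kept
```

For 3D input the two forms are interchangeable. The explicit-axis form is the one that generalizes without reordering the batch axes.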

    Approach #2

    Here's another approach using advanced indexing, combining an integer index array with a boolean mask -

    m,n = A.shape[:2]
    out = A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
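This sketch shows why the indexing above works, using hypothetical arange-filled data. NumPy treats a 2D boolean mask used as an index like the pair of integer arrays returned by its nonzero() method, and those arrays broadcast against np.arange(m)[:, None]:

```python
import numpy as np

A = np.arange(4 * 3 * 3).reshape(4, 3, 3)  # hypothetical test data, shape (4, 3, 3)
m, n = A.shape[:2]

mask = np.eye(n, dtype=bool)
# Used as an index, the 2D mask behaves like the integer arrays from
# mask.nonzero(), which here are the diagonal positions (0,0), (1,1), (2,2).
rows, cols = mask.nonzero()

# np.arange(m)[:, None] has shape (m, 1); it broadcasts against the
# length-n index arrays, giving an (m, n) result with out[k, i] = A[k, i, i].
out_mask = A[np.arange(m)[:, None], mask]
out_int = A[np.arange(m)[:, None], rows, cols]

print(out_mask.shape)                     # (4, 3)
print(np.array_equal(out_mask, out_int))  # True
```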
    

    One can also flatten each matrix into a row with reshape and index it with a flattened boolean mask -

    out = A.reshape(m,-1)[:,np.eye(n,dtype=bool).ravel()]
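This sketch shows why the flattened mask picks out the diagonal, again using hypothetical arange-filled data. The step-slice variant at the end is an alternative not taken from the original answer. It relies on the fact that, once a row-major n x n matrix is flattened, its diagonal entries sit n + 1 positions apart:

```python
import numpy as np

A = np.arange(4 * 3 * 3).reshape(4, 3, 3)  # hypothetical test data
m, n = A.shape[:2]

flat_mask = np.eye(n, dtype=bool).ravel()
# In a row-major n x n matrix flattened to length n*n, element (i, i)
# sits at position i*n + i = i*(n + 1).
print(np.flatnonzero(flat_mask))  # [0 4 8] for n = 3

# Each row of A.reshape(m, -1) is one flattened matrix, so the mask
# picks the diagonal out of every row at once.
out = A.reshape(m, -1)[:, flat_mask]

# Because those positions are evenly spaced, a step slice selects the
# same elements using basic indexing only.
out_slice = A.reshape(m, -1)[:, ::n + 1]
print(np.array_equal(out, out_slice))  # True
```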
    

    Sample run -

    In [87]: A
    Out[87]: 
    array([[[73, 52, 62],
            [20,  7,  7],
            [ 1, 68, 89]],
    
           [[15, 78, 98],
            [24, 22, 35],
            [19,  1, 91]],
    
           [[ 5, 37, 64],
            [22,  4, 43],
            [84, 45, 12]],
    
           [[24, 45, 42],
            [70, 45,  1],
            [ 6, 48, 60]]])
    
    In [88]: np.diagonal(A.T)
    Out[88]: 
    array([[73,  7, 89],
           [15, 22, 91],
           [ 5,  4, 12],
           [24, 45, 60]])
    
    In [89]: m,n = A.shape[:2]
    
    In [90]: A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
    Out[90]: 
    array([[73,  7, 89],
           [15, 22, 91],
           [ 5,  4, 12],
           [24, 45, 60]])
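To cross-check the approaches on other data, here is a minimal sketch that compares each one against a loop-based reference modeled on the question's function. The helper name diagonal_loop and the random test array are illustrative assumptions:

```python
import numpy as np

def diagonal_loop(A):
    # Hypothetical loop-based reference, modeled on the function in the question.
    return np.vstack([np.diagonal(a) for a in A])

rng = np.random.default_rng(0)
A = rng.integers(0, 100, size=(4, 3, 3))  # random stack of 4 matrices, each 3x3
m, n = A.shape[:2]

candidates = {
    "diagonal_T": np.diagonal(A.T),
    "diagonal_axes": np.diagonal(A, axis1=-2, axis2=-1),
    "fancy_index": A[np.arange(m)[:, None], np.eye(n, dtype=bool)],
    "reshape_mask": A.reshape(m, -1)[:, np.eye(n, dtype=bool).ravel()],
}

expected = diagonal_loop(A)
for name, result in candidates.items():
    # Each candidate should match the loop's (m, n) output exactly.
    print(name, result.shape, np.array_equal(result, expected))
```

There is one difference the comparison does not show. In NumPy 1.9 and later, np.diagonal returns a read-only view of the input, whereas the two indexing approaches return new arrays. Call .copy() on the np.diagonal result if you need to modify it.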