Is there a more pythonic way of doing the following:
    import numpy as np

    def diagonal(A):
        (x, y, y) = A.shape
        diags = []
        for a in A:
            diags.append(np.diagonal(a))
        result = np.vstack(diags)
        assert result.shape == (x, y)
        return result
Approach #1
A clean way would be with np.diagonal on a transposed version of the input array, like so -

    np.diagonal(A.T)
Basically, we reverse the order of the input array's axes with A.T so that np.diagonal, which by default takes the diagonal over the first two axes, ends up extracting it along the last two axes of A. This also works for arrays with more than three dimensions, though since A.T reverses all axes, the remaining leading axes come out in reversed order.
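As a hedged aside, np.diagonal also accepts axis1 and axis2 arguments, so the same diagonals can be requested without transposing. The sketch below uses a made-up test array (not the one from the sample run), checks both forms against the loop from the question, and shows how the two forms order the leading axes for a 4-D input.

```python
import numpy as np

# Hypothetical test array: four 3x3 matrices stacked along axis 0.
A = np.arange(4 * 3 * 3).reshape(4, 3, 3)

# Reference result: the loop from the question, written as a comprehension.
expected = np.vstack([np.diagonal(a) for a in A])

via_transpose = np.diagonal(A.T)               # diagonal over A's last two axes
via_axes = np.diagonal(A, axis1=-2, axis2=-1)  # same axes, named explicitly

assert np.array_equal(via_transpose, expected)
assert np.array_equal(via_axes, expected)

# With more than three dimensions the two forms differ in axis order.
B = np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3)
shape_transpose = np.diagonal(B.T).shape               # leading axes reversed
shape_axes = np.diagonal(B, axis1=-2, axis2=-1).shape  # leading axes preserved
```

For the 3-D case in the question the two calls give the same result; the explicit-axes form is mainly useful when the order of the leading axes matters.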
Approach #2
Here's another approach that combines integer-array indexing with a boolean mask -

    m,n = A.shape[:2]
    out = A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
One can also reshape to 2D and then index with a slice and a flattened boolean mask -

    out = A.reshape(m,-1)[:,np.eye(n,dtype=bool).ravel()]
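To see why the identity mask picks out the diagonals, here is a minimal sketch under the assumption that A has shape (m, n, n). A boolean index behaves like the index arrays returned by its nonzero() method, which for np.eye(n, dtype=bool) are the matching row and column positions of the diagonal. The test array and the names rows, cols, and out_explicit are illustrative only.

```python
import numpy as np

# Hypothetical test array of shape (m, n, n).
A = np.arange(4 * 3 * 3).reshape(4, 3, 3)
m, n = A.shape[:2]

mask = np.eye(n, dtype=bool)   # True on the diagonal of an n x n grid
rows, cols = mask.nonzero()    # positions of the True entries: (i, i) pairs

# The (m, 1) integer index broadcasts against the n diagonal positions,
# giving an (m, n) result with out[k, i] == A[k, i, i].
out_mask = A[np.arange(m)[:, None], mask]
out_explicit = A[np.arange(m)[:, None], rows, cols]

# Reshape variant: flatten each matrix, then keep the diagonal columns.
out_reshape = A.reshape(m, -1)[:, mask.ravel()]

expected = np.diagonal(A, axis1=1, axis2=2)
assert np.array_equal(out_mask, expected)
assert np.array_equal(out_explicit, expected)
assert np.array_equal(out_reshape, expected)
```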
Sample run -
    In [87]: A
    Out[87]:
    array([[[73, 52, 62],
            [20,  7,  7],
            [ 1, 68, 89]],

           [[15, 78, 98],
            [24, 22, 35],
            [19,  1, 91]],

           [[ 5, 37, 64],
            [22,  4, 43],
            [84, 45, 12]],

           [[24, 45, 42],
            [70, 45,  1],
            [ 6, 48, 60]]])

    In [88]: np.diagonal(A.T)
    Out[88]:
    array([[73,  7, 89],
           [15, 22, 91],
           [ 5,  4, 12],
           [24, 45, 60]])

    In [89]: m,n = A.shape[:2]

    In [90]: A[np.arange(m)[:,None],np.eye(n,dtype=bool)]
    Out[90]:
    array([[73,  7, 89],
           [15, 22, 91],
           [ 5,  4, 12],
           [24, 45, 60]])