
Multidimensional indexing in Numpy with tuples as indices for certain axes


I have a 3D NumPy array that is really a stack of matrices, and I want to set the diagonal of each matrix to zero using the approach below. When I print the tuple of index arrays and `din`, they look exactly the same, yet indexing with them returns different results.

import numpy as np

m = np.random.normal(0, 0.2, (10, 4, 4))
din = np.diag_indices(m.shape[1], ndim=2)

m[:, np.array([0, 1, 2, 3]), np.array([0, 1, 2, 3])]  # returns the diagonals, shape (10, 4), as expected
m[:, tuple(din)]  # does not return the diagonals: shape (10, 2, 4, 4), i.e. whole matrices
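A quick shape check makes the difference visible. This is a minimal sketch that reuses the same `m` and `din` as above; the seeded generator is an added assumption for reproducibility. When `din` is passed as a single element of the index, NumPy converts the nested tuple into one (2, 4) integer array and applies it to axis 1 only, instead of using one index array per axis.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator (assumption, for reproducible output)
m = rng.normal(0, 0.2, (10, 4, 4))
din = np.diag_indices(m.shape[1], ndim=2)

# One index array per axis: axes 1 and 2 are indexed together -> diagonals.
diag = m[:, np.arange(4), np.arange(4)]

# din is ONE index element here: the tuple of two arrays becomes a (2, 4)
# integer array that selects whole rows along axis 1 only.
nested = m[:, din]

print(diag.shape)    # (10, 4)
print(nested.shape)  # (10, 2, 4, 4)
```

Since `din[0]` is just `[0, 1, 2, 3]`, each half of `nested` is simply the full stack of matrices again, which is why the second expression appears to "return the array".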

What did I miss here?


Solution

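  • As described in the comments, you need to unpack the indices, so that each array in `din` indexes its own axis instead of the whole tuple being treated as a single index for axis 1.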

    Since Python 3.11, you can unpack directly inside the subscript:

    m[:, *din]
    

    Output (shape (10, 4); the values vary because `m` is random):

    array([[ 8.61622699e-02, -1.46919069e-01, -9.37771599e-02,
             1.94698315e-03],
           [ 1.60933774e-01, -2.77077615e-02, -1.74135776e-01,
            -1.72223723e-01],
           [-1.54804225e-01,  1.08146714e-01,  2.51844877e-01,
            -2.91622737e-02],
           [ 1.22213756e-02,  1.59703456e-02, -1.41757563e-01,
            -5.02470362e-02],
           [ 1.49296012e-01, -9.60208199e-03, -4.82484338e-01,
             1.58012139e-02],
           [-3.09847219e-01, -1.13959996e-01, -6.71019475e-01,
             3.17810448e-01],
           [ 2.04860543e-04, -2.16311908e-01,  1.39098046e-01,
            -1.40102017e-01],
           [-5.82402679e-02,  2.55831587e-01, -3.74597159e-01,
             1.23205316e-01],
           [-1.23942861e-01,  1.40365188e-02, -2.16884333e-02,
            -2.08800511e-02],
           [ 1.02934324e-01, -1.81953630e-01,  2.35600757e-01,
            -2.29315601e-01]])
    

    This syntax is not supported in older Python versions; there, build a single index tuple instead:

    m[tuple((slice(None), *din))]
    
    # or
    m[(slice(None), *din)]
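Because the unpacked form is an ordinary advanced index, the same expression also works on the left-hand side of an assignment, which is what the question was ultimately after (zeroing every diagonal). A minimal sketch, assuming a seeded random `m` purely for illustration; it uses the pre-3.11-compatible tuple form so it runs on any Python 3 version that NumPy supports:

```python
import numpy as np

rng = np.random.default_rng(1)  # seeded generator (assumption, for illustration)
m = rng.normal(0, 0.2, (10, 4, 4))
original = m.copy()  # kept only to check that off-diagonal entries are untouched
din = np.diag_indices(m.shape[1], ndim=2)

# Full slice over the batch axis, then one index array per matrix axis.
# Equivalent to m[:, din[0], din[1]] = 0 (and to m[:, *din] = 0 on Python 3.11+).
m[(slice(None), *din)] = 0

# Every matrix in the stack now has a zero diagonal.
print(np.all(np.diagonal(m, axis1=1, axis2=2) == 0))  # True

# Boolean mask selecting the off-diagonal positions of a 4x4 matrix.
off_diag = ~np.eye(4, dtype=bool)
print(np.array_equal(m[:, off_diag], original[:, off_diag]))  # True
```

Assignment through advanced indexing writes into `m` in place, so no copy of the array is needed.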