Tags: python, numpy, numpy-ndarray

Can I create a multidimensional array containing a unit matrix without nested loops?


Suppose I have a NumPy array with n indices, where the first n-2 are counting indices and the last 2 index a square M×M matrix. I want to initialize this structure so that every M×M block contains a copy of the unit (identity) matrix.

Example (here n=3, M=2):

import numpy

A = numpy.zeros((3, 2, 2))
for k in range(3):
    A[k, :, :] = numpy.eye(2)

Another example (here n=4, M=5):

import numpy

B = numpy.zeros((3, 4, 5, 5))
for k1 in range(3):
    for k2 in range(4):
        B[k1, k2, :, :] = numpy.eye(5)

Is there a way to do this without relying on nested loops?


Solution

  • You can use np.repeat on the identity matrix after adding a leading axis with [None]:

    import numpy as np
    A = np.repeat(np.eye(2)[None], 3, axis=0)
    

    For several leading dimensions, combine it with reshape:

    extra = (3, 4)
    M = 5
    B = np.repeat(np.eye(M)[None], np.prod(extra), axis=0).reshape(extra+(M, M))
    

    Or with np.tile, using a repetition factor of 1 for the two matrix axes:

    extra = (3, 4)
    B = np.tile(np.eye(5), extra+(1, 1))
    

    Or start from np.zeros and set the diagonal of every M×M block by indexing:

    B = np.zeros((3, 4, 5, 5))
    x = np.arange(B.shape[-1])
    B[..., x, x] = 1
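The indexing approach generalizes to any number of leading counting axes, because the Ellipsis stands in for all of them. Below is a minimal sketch that wraps it in a function; the name `stacked_eye` and its `extra`/`M`/`dtype` parameters are hypothetical, chosen only for illustration. It also cross-checks the result against the nested-loop version from the question and against the np.tile version.

```python
import numpy as np


def stacked_eye(extra, M, dtype=float):
    """Hypothetical helper: return an array of shape extra + (M, M)
    in which every trailing M x M block is the identity matrix."""
    extra = tuple(extra)
    out = np.zeros(extra + (M, M), dtype=dtype)
    idx = np.arange(M)
    # The Ellipsis covers all leading "counting" axes; pairing idx with idx
    # selects the diagonal entries (i, i) of each trailing M x M block.
    out[..., idx, idx] = 1
    return out


# Nested-loop reference from the question.
B_loop = np.zeros((3, 4, 5, 5))
for k1 in range(3):
    for k2 in range(4):
        B_loop[k1, k2, :, :] = np.eye(5)

B = stacked_eye((3, 4), 5)
print(B.shape)                                               # (3, 4, 5, 5)
print(np.array_equal(B, B_loop))                             # True
print(np.array_equal(B, np.tile(np.eye(5), (3, 4, 1, 1))))   # True
```

With `extra=()` the helper simply returns `np.eye(M)`. If the copies never need to be modified, `np.broadcast_to(np.eye(M), extra + (M, M))` is another option: it returns a read-only view instead of allocating the full array.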