python numpy numpy-ndarray

How to repeat the inner rows of a 3D matrix (NxMxV) to get a new matrix (Nx2MxV) using np.repeat?


Here is an example of what I want:

original matrix:
[[[5 1 4]]

 [[0 9 5]]

 [[8 0 9]]]
matrix I want:
[[[5 1 4]
  [5 1 4]]

 [[0 9 5]
  [0 9 5]]

 [[8 0 9]
  [8 0 9]]]

I have tried np.repeat(A, 2, axis=0), but that does not work: it duplicates the outer blocks instead of the inner rows, giving:

[[[5 1 4]]

 [[5 1 4]]

 [[0 9 5]]

 [[0 9 5]]

 [[8 0 9]]

 [[8 0 9]]]
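Comparing shapes makes the problem visible. A minimal sketch, assuming A is built with np.array from the values shown in the question:

```python
import numpy as np

# The 3x1x3 array from the question (N=3, M=1, V=3)
A = np.array([[[5, 1, 4]],
              [[0, 9, 5]],
              [[8, 0, 9]]])

print(A.shape)                        # (3, 1, 3)
# axis=0 duplicates whole blocks along N
print(np.repeat(A, 2, axis=0).shape)  # (6, 1, 3)
# axis=1 duplicates the rows inside each block
print(np.repeat(A, 2, axis=1).shape)  # (3, 2, 3)
```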

Solution

  • You want to repeat along axis=1:

    np.repeat(A, 2, axis=1)
    

    Output:

    array([[[5, 1, 4],
            [5, 1, 4]],
    
           [[0, 9, 5],
            [0, 9, 5]],
    
           [[8, 0, 9],
            [8, 0, 9]]])
    

    NB: remember to check the shapes of your arrays. A.shape is (3, 1, 3); you want (3, 2, 3), not (6, 1, 3).
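The same axis=1 call works for any M. One thing to watch when M > 1: np.repeat places the copies of each row next to each other, while np.tile (or concatenating the array with itself along axis=1) appends a copy of the whole M-row block. A small sketch using a hypothetical 1x2x2 array B to show the difference:

```python
import numpy as np

# Hypothetical example with N=1, M=2, V=2
B = np.array([[[1, 2],
               [3, 4]]])

# np.repeat duplicates each row in place: row order 0, 0, 1, 1
r = np.repeat(B, 2, axis=1)
# np.tile repeats the whole block along axis 1: row order 0, 1, 0, 1
t = np.tile(B, (1, 2, 1))
# Concatenating B with itself along axis 1 gives the same result as tile
c = np.concatenate([B, B], axis=1)

print(r.shape, t.shape)  # (1, 4, 2) (1, 4, 2)
print(r[0].tolist())     # [[1, 2], [1, 2], [3, 4], [3, 4]]
print(t[0].tolist())     # [[1, 2], [3, 4], [1, 2], [3, 4]]
```

For the question's M=1 case all three approaches produce the same (3, 2, 3) array, so np.repeat(A, 2, axis=1) is enough; the distinction only matters once each block has more than one row.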