python numpy reshape

"Reshape" along a specific axis - Numpy


I have a 4D numpy array. For example, I have 2 batches of 3 two-dimensional matrices, giving a shape of (2, 3, 4, 5). In each batch, I want to concatenate the three 2D matrices horizontally (along the last dimension).

The output should have shape (2, 4, 5 * 3).


A smaller example for reproduction with shape (2, 3, 2, 2):

import numpy as np

a1, a2, a3, a4 = 2, 3, 2, 2
arr = np.arange(a1*a2*a3*a4).reshape((a1, a2, a3, a4))

[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]]

Its first row (reading along the last dimension) should be: [0, 1, 4, 5, 8, 9]

Thanks in advance.


Solution

  • While some form of concatenate can be used, transpose often works for this kind of problem:

    In [532]: arr.transpose(0,2,1,3).reshape(2,2,6)
    Out[532]: 
    array([[[ 0,  1,  4,  5,  8,  9],
            [ 2,  3,  6,  7, 10, 11]],
    
           [[12, 13, 16, 17, 20, 21],
            [14, 15, 18, 19, 22, 23]]])
    

    Without the final reshape, the first "line" is

    In [533]: arr.transpose(0,2,1,3)
    Out[533]: 
    array([[[[ 0,  1],
             [ 4,  5],
             [ 8,  9]],
    

    And for the example with all different dimensions

    In [534]: x=np.ones((2,3,4,5))
    In [535]: x.transpose(0,2,1,3).shape
    Out[535]: (2, 4, 3, 5)
    In [536]: x.transpose(0,2,1,3).reshape(2,4,3*5).shape
    Out[536]: (2, 4, 15)
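
The hard-coded shapes in the sessions above can also be derived from `arr.shape`. The following is a minimal sketch, not part of the original answer: the helper name `concat_blocks_along_last` is hypothetical, and it assumes a 4D input whose axis 1 indexes the matrices to be joined. It also cross-checks the transpose approach against an explicit `np.concatenate`.

```python
import numpy as np

def concat_blocks_along_last(arr):
    """Join the matrices on axis 1 side by side along the last axis.

    Hypothetical helper: maps shape (b, n, r, c) to (b, r, n * c).
    """
    b, n, r, c = arr.shape
    # Move the "which matrix" axis next to the column axis, then merge the two.
    return arr.transpose(0, 2, 1, 3).reshape(b, r, n * c)

arr = np.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)
out = concat_blocks_along_last(arr)

# Cross-check against an explicit concatenate over the three matrices.
expected = np.concatenate([arr[:, i] for i in range(arr.shape[1])], axis=-1)
assert np.array_equal(out, expected)
print(out[0, 0])  # [0 1 4 5 8 9]
```

The `np.concatenate` version is arguably easier to read, while the transpose version avoids building a Python list of slices.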
    

    Sometimes figuring out the transpose takes some trial and error. But here you want to keep the first dimension as is, and also the last, so we are left with swapping the two middle ones.

    By itself, transpose makes a view, but the final reshape has to make a copy since it's reordering the 'raveled' values.
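
The view-versus-copy behaviour described above can be checked directly. This is a small sketch using `np.shares_memory` and the array's `flags`, with the same small example array as in the question; the variable names `swapped` and `merged` are just for illustration.

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 2, 2)

swapped = arr.transpose(0, 2, 1, 3)
merged = swapped.reshape(2, 2, 6)

# transpose only changes strides, so it still points at arr's buffer.
print(np.shares_memory(arr, swapped))   # True
# The transposed layout is not contiguous, so reshape had to copy.
print(swapped.flags['C_CONTIGUOUS'])    # False
print(np.shares_memory(arr, merged))    # False

# Writing to the result therefore leaves the original untouched.
merged[0, 0, 0] = -1
print(arr[0, 0, 0, 0])                  # 0
```

In practice this means the result can be modified safely, at the cost of one extra allocation the size of the input.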