Tags: python, arrays, numpy, reshape, transpose

Why does .reshape(a, b) != .reshape(b, a).T?


I ran into this problem whilst flattening images. Consider the following array

>>> import numpy as np
>>> arr = np.array([[[1, 2],
...                  [3, 4]],
...                 [[1, 2],
...                  [3, 4]]])
>>> arr.shape
(2, 2, 2)

I wish to reshape this into a 4x2 array where each column is a flattened image.

1. Using .reshape(2, 4).T

This achieves the desired result

>>> arr_flatten = arr.reshape(2, 4).T
>>> print(arr_flatten)
[[1 1]
 [2 2]
 [3 3]
 [4 4]]

2. Using .reshape(4, 2)

This does not achieve the desired result

>>> arr_flatten = arr.reshape(4, 2)
>>> print(arr_flatten)
[[1 2]
 [3 4]
 [1 2]
 [3 4]]

Why does this second approach not work?


Solution

  • The two reshapes build the array differently. Both read the elements in the same order (1, 2, 3, 4, 1, 2, 3, 4); they differ only in where they break rows. Think of .reshape(2,4) as traversing the array while counting elements: every time the count reaches 4, it starts a new row and restarts the count. With the count in parentheses, it goes 1 (1), 2 (2), 3 (3), 4 (4), new row, 1 (1), 2 (2), 3 (3), 4 (4). For .reshape(4,2), it counts up to 2 and then starts a new row: 1 (1), 2 (2), new row, 3 (1), 4 (2), new row, 1 (1), 2 (2), new row, 3 (1), 4 (2).

    Hence, .reshape(2,4) produces

    array([[1, 2, 3, 4],
           [1, 2, 3, 4]])
    

    And .reshape(4,2) produces

    array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]])
    

    Clearly, the transpose of the first is not equal to the second. As for why one gives the right answer, that comes down to the meaning of the data and how it is ordered. In this case, the images are indexed along axis 0, so .reshape(2,4) traverses one entire image (all four elements of arr[0], then all four of arr[1]) before starting a new row. Each row is therefore one flattened image, and transposing turns those rows into the desired columns.
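
The counting argument above is easier to follow when every element is distinct. The following is a minimal sketch, assuming a stand-in array built with np.arange (not the asker's data) and NumPy's default row-major order; the variable names are illustrative only.

```python
import numpy as np

# Hypothetical stand-in for the two 2x2 images, using distinct values
# so that each element's final position can be tracked.
arr = np.arange(8).reshape(2, 2, 2)

# Both reshapes read the elements in the same row-major order;
# they only differ in where a new row starts.
flat = arr.ravel()

rows_of_4 = arr.reshape(2, 4)  # new row every 4 elements: one image per row
rows_of_2 = arr.reshape(4, 2)  # new row every 2 elements: one image row per row

# One image per column: flatten each image into a row, then transpose.
# -1 lets NumPy infer the flattened image size from the remaining axes.
columns = arr.reshape(arr.shape[0], -1).T

print(rows_of_4)
print(rows_of_2)
print(columns)
```

Here columns[:, 0] holds the elements of arr[0] and columns[:, 1] those of arr[1], whereas rows_of_2 simply stacks the rows of both images on top of each other.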