python · numpy · tensorflow

How to implement pixel-shuffle


In TensorFlow, there is a pixel-shuffle method called depth_to_space. It does the following: suppose we have an image (an array) with dimensions (4, 4, 4). The method shuffles the values of this array so that we get an array of size (8, 8, 1), in the way depicted in the image below:

A depth-to-space pixel-shuffle operation

I have tried for a few hours now to recreate this method in NumPy using plain NumPy functions like reshape, transpose, etc., but I have not been able to succeed. Does anyone know how to implement it?

A very similar problem can be found in How to implement tf.space_to_depth with numpy?. However, this question considers the space_to_depth method, which is the inverse operation.


Solution

  • Here is a "channels-first" solution (i.e. assuming your array dimensions are ordered channels×height×width):

    import numpy as np
    
    # Create some data ("channels-first" version)
    a = (np.ones((1, 3, 3), dtype=int) *
         np.arange(1, 5)[:, np.newaxis, np.newaxis])  # 4×3×3
    
    c, h, w = a.shape  # channels, height, width
    p = int(np.sqrt(c))  # height and width of one "patch"
    assert p * p == c  # Sanity-check
    a_flat = a.reshape(p, p, h, w).transpose(2, 0, 3, 1).reshape(h * p, w * p)  # 6×6
    
    print(a)
    # [[[1 1 1]
    #   [1 1 1]
    #   [1 1 1]]
    
    #  [[2 2 2]
    #   [2 2 2]
    #   [2 2 2]]
    
    #  [[3 3 3]
    #   [3 3 3]
    #   [3 3 3]]
    
    #  [[4 4 4]
    #   [4 4 4]
    #   [4 4 4]]]
    
    print(a_flat)
    # [[1 2 1 2 1 2]
    #  [3 4 3 4 3 4]
    #  [1 2 1 2 1 2]
    #  [3 4 3 4 3 4]
    #  [1 2 1 2 1 2]
    #  [3 4 3 4 3 4]]
    

    And here is the corresponding "channels-last" version (i.e. assuming your array dimensions are ordered height×width×channels):

    import numpy as np
    
    # Create some data ("channels-last" version)
    a = np.ones((3, 3, 1), dtype=int) * np.arange(1, 5)  # 3×3×4
    
    h, w, c = a.shape  # height, width, channels
    p = int(np.sqrt(c))  # height and width of one "patch"
    assert p * p == c  # Sanity-check
    a_flat = a.reshape(h, w, p, p).transpose(0, 2, 1, 3).reshape(h * p, w * p)  # 6×6
    
    print(a)
    # [[[1 2 3 4]
    #   [1 2 3 4]
    #   [1 2 3 4]]
    
    #  [[1 2 3 4]
    #   [1 2 3 4]
    #   [1 2 3 4]]
    
    #  [[1 2 3 4]
    #   [1 2 3 4]
    #   [1 2 3 4]]]
    
    print(a_flat)
    # [[1 2 1 2 1 2]
    #  [3 4 3 4 3 4]
    #  [1 2 1 2 1 2]
    #  [3 4 3 4 3 4]
    #  [1 2 1 2 1 2]
    #  [3 4 3 4 3 4]]
    

    In both cases, the idea is the same:

    1. With the first reshape, we split the channel dimension (or "depth") of length c into what will become a p×p patch (note that p·p = c; in the question's example, p = 2 and c = 4).
    2. With transpose, we place the patch height behind the current image height and the patch width behind the current image width.
    3. With the second reshape, we fuse the current image height and patch height into the new image height, and the current image width and patch width into the new image width.
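    The same three steps generalize beyond the 2×2 case. Here is a sketch of a depth_to_space-style helper for an arbitrary block size r (the function name and the batch-free signature are my own; it assumes a channels-last array whose channel count is divisible by r²):

    ```python
    import numpy as np

    def depth_to_space(a, r):
        """Pixel-shuffle a channels-last array (h, w, c) -> (h*r, w*r, c // r**2).

        A sketch following the reshape/transpose/reshape recipe above;
        assumes c is divisible by r*r.
        """
        h, w, c = a.shape
        assert c % (r * r) == 0, "channel count must be divisible by r*r"
        c_out = c // (r * r)
        # 1. Split channels into (r, r, c_out); 2. move the two r-axes next to
        # h and w; 3. fuse them into the enlarged spatial dimensions.
        return (a.reshape(h, w, r, r, c_out)
                 .transpose(0, 2, 1, 3, 4)
                 .reshape(h * r, w * r, c_out))

    a = np.ones((3, 3, 1), dtype=int) * np.arange(1, 5)  # 3×3×4
    out = depth_to_space(a, 2)
    print(out.shape)  # (6, 6, 1)
    ```

    With r = 2 and c_out = 1, `out[:, :, 0]` reproduces the a_flat array from the channels-last example above.
    
    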

    Update: Using einops

    Using rearrange() from the einops package, as suggested in Mercury's comment, corresponding solutions could look as follows:

    import numpy as np
    from einops import rearrange
    
    # Channels first
    a = (np.ones((1, 3, 3), dtype=int) *
         np.arange(1,5)[:, np.newaxis, np.newaxis])  # 4×3×3
    p = int(np.sqrt(a.shape[0]))  # height and width of one "patch"
    a_flat = rearrange(a, "(hp wp) h w -> (h hp) (w wp)", hp=p)
    
    # Channels last
    a = np.ones((3, 3, 1), dtype=int) * np.arange(1, 5)  # 3×3×4
    p = int(np.sqrt(a.shape[-1]))  # height and width of one "patch"
    a_flat = rearrange(a, "h w (hp wp) -> (h hp) (w wp)", hp=p)
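    For completeness, the inverse operation mentioned in the question (space_to_depth) simply runs the same reshape/transpose steps backwards. A channels-last sketch (the function name is my own; it assumes both spatial sizes are divisible by p):

    ```python
    import numpy as np

    def space_to_depth(a_flat, p):
        """Inverse pixel-shuffle: (h*p, w*p) -> (h, w, p*p).

        A sketch reversing the channels-last recipe above; assumes both
        spatial sizes of a_flat are divisible by p.
        """
        hp, wp = a_flat.shape
        h, w = hp // p, wp // p
        # Split each spatial axis into (image, patch), move the two patch
        # axes to the back, then fuse them into the channel dimension.
        return (a_flat.reshape(h, p, w, p)
                      .transpose(0, 2, 1, 3)
                      .reshape(h, w, p * p))

    a_flat = np.tile(np.arange(1, 5).reshape(2, 2), (3, 3))  # 6×6
    a = space_to_depth(a_flat, 2)
    print(a.shape)  # (3, 3, 4)
    ```

    Applied to the 6×6 a_flat from the examples above, this recovers the original 3×3×4 array, so the two functions are inverses of each other.
    
    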