Tags: python, algorithm, image, recursion, combinations

Fastest way to combine image patches given as 4D array in python


I have a 4D array of shape (N, W, H, 3), where N is the number of patches, W and H are the width and height of each image patch, and 3 is the number of color channels. Assume that these patches were generated by taking an original image I and dividing it into small squares. The division proceeds row by row, so if we divide the image into 3x3 patches (9 total), each 10x10 pixels, the 4D array has shape (9, 10, 10, 3) and its elements are ordered [patch11, patch12, patch13, patch21, patch22, patch23, patch31, patch32, patch33].

My question is: what is the most efficient way to combine these patches back into the original image, using only plain Python functions and NumPy (no PIL or OpenCV)?

Thank you so much.

I can write a double for loop that does the job as below, but I'm wondering if there is a better algorithm that can provide faster performance:

import numpy as np

def reconstruct_image(patches, num_rows, num_cols):
    # num_rows and num_cols are the number of patches in the rows and columns respectively
    patch_height, patch_width, channels = patches.shape[1], patches.shape[2], patches.shape[3]

    # Initialize the empty array for the full image
    full_image = np.zeros((num_rows * patch_height, num_cols * patch_width, channels), dtype=patches.dtype)

    # Iterate over the rows and columns of patches
    for i in range(num_rows):
        for j in range(num_cols):
            # Get the index of the current patch in the 4D array
            patch_index = i * num_cols + j
            # Place the patch in the appropriate position in the full image
            full_image[i*patch_height:(i+1)*patch_height, j*patch_width:(j+1)*patch_width, :] = patches[patch_index]

    return full_image

N = 9  # Number of patches
W, H, C = 10, 10, 3  # Patch dimensions (WxHxC)
num_rows, num_cols = 3, 3  # Number of patches in rows and columns (3x3 patches)
patches = np.random.rand(N, W, H, C)  # Example patch data

reconstructed_image = reconstruct_image(patches, num_rows, num_cols)

Solution

  • Here's a fast, pure-NumPy one-liner:

    def reconstruct_image_2():
        return patches.reshape(num_rows, num_cols, W, H, C).swapaxes(1, 2).reshape(num_rows*W, num_cols*H, C)
    
    reconstructed_image_2 = reconstruct_image_2()
    
    assert np.all(reconstructed_image == reconstructed_image_2) # True
    

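    Explanation: The first reshape restructures the array as a 2D grid of patches with shape (num_rows, num_cols, W, H, C). swapaxes(1, 2) then reorders it to (num_rows, W, num_cols, H, C), so the pixel-row axis of each patch sits directly after the patch-row axis. The final reshape merges the first two axes into one and the next two axes into another, which effectively concatenates the patches along rows and columns.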
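The reshape/swapaxes steps described above can be checked with a small round trip on non-square patches, where a wrong axis order would show up immediately. This is only a sketch: `split_into_patches` and `merge_patches` are hypothetical helper names, and it assumes (as the question's loop does) that axis 1 of the patch array indexes pixel rows and axis 2 indexes pixel columns.

```python
import numpy as np

def split_into_patches(image, num_rows, num_cols):
    """Hypothetical helper: cut an image into row-major patches (the inverse step)."""
    img_h, img_w, channels = image.shape
    ph, pw = img_h // num_rows, img_w // num_cols
    return (image.reshape(num_rows, ph, num_cols, pw, channels)
                 .swapaxes(1, 2)  # -> (num_rows, num_cols, ph, pw, C)
                 .reshape(num_rows * num_cols, ph, pw, channels))

def merge_patches(patches, num_rows, num_cols):
    """Same three steps as the one-liner, with the intermediate shapes named."""
    _, ph, pw, channels = patches.shape
    grid = patches.reshape(num_rows, num_cols, ph, pw, channels)  # 2D grid of patches
    interleaved = grid.swapaxes(1, 2)  # -> (num_rows, ph, num_cols, pw, C)
    return interleaved.reshape(num_rows * ph, num_cols * pw, channels)

# Assumed example sizes: a 2 x 3 grid of 4 x 5 patches, i.e. an 8 x 15 image.
image = np.arange(8 * 15 * 3).reshape(8, 15, 3)
patches = split_into_patches(image, num_rows=2, num_cols=3)
restored = merge_patches(patches, num_rows=2, num_cols=3)

print(patches.shape)                      # (6, 4, 5, 3)
print(np.array_equal(restored, image))    # True
```

Using distinct values (np.arange) rather than random data makes it easy to inspect individual patches, for example that patches[1] is the block image[0:4, 5:10].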

    Timing comparison:

    # %timeit is an IPython/Jupyter magic, so no import is needed
    
    %timeit reconstruct_image(patches, num_rows, num_cols)
    
    # 6.2 µs ± 16.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    %timeit reconstruct_image_2()
    
    # 1.56 µs ± 2.57 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
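The timings above come from IPython's %timeit on a very small input (9 patches of 10x10). Outside IPython, a comparable measurement could be sketched with the standard-library timeit module, as below. The function names, the grid size, and the patch size are assumptions chosen for illustration, and no particular result is implied: the final reshape has to copy the data because the swapped array is no longer contiguous, so the gap between the two approaches may change with input size.

```python
import timeit
import numpy as np

def reconstruct_loop(patches, num_rows, num_cols):
    """Double-loop version, as in the question."""
    ph, pw, channels = patches.shape[1:]
    full = np.zeros((num_rows * ph, num_cols * pw, channels), dtype=patches.dtype)
    for i in range(num_rows):
        for j in range(num_cols):
            full[i*ph:(i+1)*ph, j*pw:(j+1)*pw, :] = patches[i * num_cols + j]
    return full

def reconstruct_reshape(patches, num_rows, num_cols):
    """Reshape/swapaxes version, taking its inputs as arguments instead of globals."""
    ph, pw, channels = patches.shape[1:]
    return (patches.reshape(num_rows, num_cols, ph, pw, channels)
                   .swapaxes(1, 2)
                   .reshape(num_rows * ph, num_cols * pw, channels))

# Assumed benchmark size: a 16 x 16 grid of 32 x 32 RGB patches.
num_rows, num_cols = 16, 16
patches = np.random.rand(num_rows * num_cols, 32, 32, 3)

# Check that both versions agree before timing them.
assert np.array_equal(reconstruct_loop(patches, num_rows, num_cols),
                      reconstruct_reshape(patches, num_rows, num_cols))

repeats = 50
timings = {}
for fn in (reconstruct_loop, reconstruct_reshape):
    total = timeit.timeit(lambda: fn(patches, num_rows, num_cols), number=repeats)
    timings[fn.__name__] = total / repeats
    print(f"{fn.__name__}: {timings[fn.__name__] * 1e6:.1f} µs per call")
```

Passing the patches and grid dimensions as arguments keeps the functions reusable; it is worth rerunning with the patch counts and sizes of the real data before drawing conclusions.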