Tags: python, deep-learning, pytorch, computer-vision, transformer-model

Split an image into small patches


I want to implement the Vision Transformer model. In the paper, the authors state that the input image is split into small patches of a fixed resolution. For example, if the image is 64x64 and the patch resolution is 16x16, the image is split into 16 patches, each of resolution 16x16. The final shape is (N, P, P, C), where N is the number of patches, P is the patch resolution, and C is the number of channels.

Here is what I tried so that the splitting is vectorized:

def image_to_patches_fast(image, res_patch):

    (H, W, C) = get_image_shape(image)

    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H * W) // (P ** 2)

    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1, 2, 0)
    image_patches = image_tensor.view(N, P, P, C)
    return image_patches

The function runs, but the output is not as intended: when I visualize the patches, they do not seem to be positioned correctly. Here is an example:

The input image: input image

The visualization of the output patches: patches

The function to visualize the patches:

def show_patches(patches):

    N, P = patches.shape[0], patches.shape[1]

    nrows, ncols = int(N ** 0.5), int(N ** 0.5)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    for row in range(nrows):
        for col in range(ncols):
            idx = col + (row * ncols)
            axes[row][col].imshow(patches[idx, :, :, :])
            axes[row][col].axis("off")

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.1, hspace=0.1)
    plt.show()
    

I tried another function to split the image; it works as expected but is slower because it uses loops:

def image_to_patches_slow(image, res_patch):

    (H, W, C) = get_image_shape(image)

    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H * W) // (P ** 2)

    nrows, ncols = int(N ** 0.5), int(N ** 0.5)

    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1, 2, 0)
    image_patches = torch.zeros((N, P, P, C), dtype=torch.int)

    for row in range(nrows):
        s_row = row * P
        e_row = s_row + P
        for col in range(ncols):
            idx = col + (row * ncols)
            s_col = col * P
            e_col = s_col + P
            image_patches[idx] = image_tensor[s_row:e_row, s_col:e_col]

    return image_patches

Its output: slow patches

Any help would be appreciated, as this slow version bottlenecks the training.


Solution

  • This method patchifies the image using a single reshaping operation, applied per channel.
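    To see why the question's `image_tensor.view(N, P, P, C)` produces scrambled patches: `view` only reinterprets the flat row-major memory, it never reorders pixels into spatial blocks. A tiny sketch (a hypothetical 4x4 single-channel "image" split into 2x2 patches) shows the difference:

    ```python
    import torch

    # 4x4 "image" with row-major values 0..15
    img = torch.arange(16).reshape(4, 4)

    # Naive view into four 2x2 "patches": each chunk is just the next
    # four values in memory, not a spatial 2x2 block of the image.
    naive = img.view(4, 2, 2)
    # naive[0] is [[0, 1], [2, 3]], but the true top-left 2x2 block
    # of the image is [[0, 1], [4, 5]].

    # Correct patchify: split each spatial axis into (blocks, block_size),
    # bring the two block axes together, then flatten them into one.
    patches = img.reshape(2, 2, 2, 2).permute(0, 2, 1, 3).reshape(4, 2, 2)
    # patches[0] is the true top-left block [[0, 1], [4, 5]].
    ```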

    If the image dimensions are not divisible by the patch width, it crops the image by clipping off the ends. It would be better to replace this rudimentary cropping with something smarter, such as centre-cropping, resizing, or a combination (zoom then centre-crop), all of which are available in torchvision.
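    For instance, a sketch of the centre-crop variant using torchvision's functional API (the input here is a hypothetical 230x210 random tensor; swap in your own image):

    ```python
    import torch
    import torchvision.transforms.functional as TF

    patch_width = 50

    # Hypothetical RGB image as a (C, H, W) tensor
    img = torch.rand(3, 230, 210)

    # Largest dimensions divisible by the patch width
    H, W = img.shape[1], img.shape[2]
    new_H = (H // patch_width) * patch_width  # 200
    new_W = (W // patch_width) * patch_width  # 200

    # Centre-crop instead of clipping off the bottom/right ends
    img_cropped = TF.center_crop(img, [new_H, new_W])

    # Or resize (zoom), which keeps all the image content but distorts scale
    img_resized = TF.resize(img, [new_H, new_W], antialias=True)
    ```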

    Example below for a 200x200 image broken down into 50px patches.


    import torch, torchvision
    import matplotlib.pyplot as plt

    img = torchvision.io.read_image('../image.png').permute(1, 2, 0)
    
    H, W, C = img.shape
    
    patch_width = 50
    n_rows = H // patch_width
    n_cols = W // patch_width
    
    cropped_img = img[:n_rows * patch_width, :n_cols * patch_width, :]
    
    #
    # Into patches
    # [n_rows, n_cols, patch_width, patch_width, C]
    #
    patches = torch.empty(n_rows, n_cols, patch_width, patch_width, C)
    for chan in range(C):
        patches[..., chan] = (
            cropped_img[..., chan]
            .reshape(n_rows, patch_width, n_cols, patch_width)
            .permute(0, 2, 1, 3)
        )
        
    #
    # Plot
    #
    f, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5))
    
    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            axs[row_idx, col_idx].imshow(patches[row_idx, col_idx, ...] / 255)
    
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])
    f.subplots_adjust(wspace=0.05, hspace=0.05)
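
    An alternative sketch that avoids the per-channel loop entirely, using `Tensor.unfold` on both spatial axes and flattening the patch grid straight into the `(N, P, P, C)` shape the question asked for (it assumes H and W are already divisible by P, so crop first if needed):

    ```python
    import torch

    def image_to_patches(img_hwc: torch.Tensor, P: int) -> torch.Tensor:
        """Split an (H, W, C) tensor into (N, P, P, C) patches, row-major."""
        H, W, C = img_hwc.shape
        patches = (
            img_hwc
            .unfold(0, P, P)         # (H/P, W, C, P)
            .unfold(1, P, P)         # (H/P, W/P, C, P, P)
            .permute(0, 1, 3, 4, 2)  # (H/P, W/P, P, P, C)
            .reshape(-1, P, P, C)    # (N, P, P, C)
        )
        return patches

    img = torch.arange(64 * 64 * 3).reshape(64, 64, 3)
    patches = image_to_patches(img, 16)
    # patches.shape == (16, 16, 16, 3); patches[0] equals img[:16, :16, :]
    ```

    Because `unfold` returns views, the only copy happens in the final `reshape`, so this should be comparable in speed to the reshape-per-channel approach while reading more directly as "patchify".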