I want to implement the Vision Transformer (ViT) model. The paper splits the input image into small patches of a fixed resolution: for example, a 64x64 image with a 16x16 patch resolution is split into 16 patches, each 16x16. The final shape is (N, P, P, C), where N is the number of patches, P is the patch resolution, and C is the number of channels.
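To make "correctly positioned" concrete: if patches are numbered in row-major order (left to right, then top to bottom), patch n of a W-wide image with P x P patches always covers the same pixel window. A minimal sketch, assuming W is divisible by P; the helper name `patch_bounds` is hypothetical:

```python
def patch_bounds(idx, W, P):
    """Return (row_start, row_end, col_start, col_end) of patch `idx`,
    assuming row-major patch numbering and W divisible by P."""
    patches_per_row = W // P               # number of patches along the width
    row, col = divmod(idx, patches_per_row)
    return row * P, (row + 1) * P, col * P, (col + 1) * P

# 64x64 image, 16x16 patches -> (64 // 16) * (64 // 16) = 16 patches
print(patch_bounds(0, W=64, P=16))   # (0, 16, 0, 16)
print(patch_bounds(4, W=64, P=16))   # (16, 32, 0, 16): first patch of the second patch row
print(patch_bounds(15, W=64, P=16))  # (48, 64, 48, 64)
```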
Here is what I tried, so that the splitting is vectorized:
```python
import torch
import torchvision

def image_to_patches_fast(image, res_patch):
    # get_image_shape: helper not shown here; assumed to return (H, W, C)
    (H, W, C) = get_image_shape(image)
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)
    P = res_patch
    N = (H*W)//(P**2)
    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = image_tensor.view(N,P,P,C)
    return image_patches
```
The function runs, but the output is not what I intended: when I visualize the patches, something is wrong, as if the patches are not positioned correctly. Here's an example:
[Figure: visualization of the output patches]
The function I use to visualize the patches:
```python
import matplotlib.pyplot as plt

def show_patches(patches):
    N, P = patches.shape[0], patches.shape[1]
    nrows, ncols = int(N**0.5), int(N**0.5)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    for row in range(nrows):
        for col in range(ncols):
            idx = col + (row*ncols)  # row-major patch index
            axes[row][col].imshow(patches[idx,:,:,:])
            axes[row][col].axis("off")
    plt.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.1,
                        hspace=0.1)
    plt.show()
```
I tried another function to split the image. It works as expected, but it is slower because it uses loops:
```python
def image_to_patches_slow(image, res_patch):
    (H, W, C) = get_image_shape(image)
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)
    P = res_patch
    N = (H*W)//(P**2)
    nrows, ncols = int(N**0.5), int(N**0.5)
    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = torch.zeros((N,P,P,C), dtype=torch.int)
    for row in range(nrows):
        # Offsets step by the patch size P. Stepping by N only gives the right
        # answer when N == P (e.g. a 64x64 image with 16x16 patches).
        s_row = row * P
        e_row = s_row + P
        for col in range(ncols):
            idx = col + (row*ncols)
            s_col = col * P
            e_col = s_col + P
            image_patches[idx] = image_tensor[s_row:e_row, s_col:e_col]
    return image_patches
```
Any help would be appreciated, as this slow version bottlenecks the training.
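A likely reason the vectorized version looks scrambled: `view(N, P, P, C)` regroups pixels in memory (raster) order, so each "patch" is P*P consecutive pixels of the flattened image, i.e. whole image rows folded into a square, rather than a spatial P x P block. A small illustrative sketch on a toy 4x4 single-channel grid with P = 2 (toy values, not taken from the question):

```python
import torch

H = W = 4
P = 2
N = (H * W) // (P * P)                        # 4 patches
img = torch.arange(H * W).reshape(H, W, 1)    # (H, W, C) with C = 1; value = raster index

# What image_to_patches_fast does: reinterpret memory as (N, P, P, C)
wrong = img.view(N, P, P, 1)
print(wrong[0, :, :, 0].tolist())  # [[0, 1], [2, 3]]: the whole first image row folded into 2x2

# The actual top-left spatial block, for comparison
right = img[:P, :P, 0]
print(right.tolist())              # [[0, 1], [4, 5]]
```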
This method patchifies the image with a single reshape-and-permute operation, applied once per channel.
If the image dimensions are not divisible by the patch width, it crops the image by clipping off the ends. It would be better to replace this rudimentary cropping with something smarter, such as centre-cropping, resizing, or a combination (zoom, then centre-crop), all available in torchvision.
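One possible replacement for the end-clipping, sketched under the assumption of an (H, W, C) tensor as in the example: centre-crop to the largest size divisible by the patch width, so the discarded border is split between both sides. This uses plain tensor slicing rather than a torchvision transform; the function name `center_crop_to_multiple` is hypothetical.

```python
import torch

def center_crop_to_multiple(img: torch.Tensor, patch_width: int) -> torch.Tensor:
    """Centre-crop an (H, W, C) tensor so that H and W are multiples of patch_width."""
    H, W = img.shape[0], img.shape[1]
    new_h = (H // patch_width) * patch_width
    new_w = (W // patch_width) * patch_width
    top = (H - new_h) // 2     # rows dropped above; an odd leftover row goes to the bottom
    left = (W - new_w) // 2    # columns dropped on the left; an odd leftover goes to the right
    return img[top:top + new_h, left:left + new_w, :]

img = torch.arange(7 * 9 * 3).reshape(7, 9, 3)
cropped = center_crop_to_multiple(img, 4)
print(cropped.shape)  # torch.Size([4, 8, 3])
```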
The example below breaks a 200x200 image into 50px patches.
```python
import torch
import torchvision
import matplotlib.pyplot as plt

img = torchvision.io.read_image('../image.png').permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
H, W, C = img.shape
patch_width = 50

n_rows = H // patch_width
n_cols = W // patch_width

# Clip off any remainder so both sides are multiples of patch_width
cropped_img = img[:n_rows * patch_width, :n_cols * patch_width, :]

#
# Into patches
# [n_rows, n_cols, patch_width, patch_width, C]
#
patches = torch.empty(n_rows, n_cols, patch_width, patch_width, C)
for chan in range(C):
    patches[..., chan] = (
        cropped_img[..., chan]
        # split rows into (n_rows, patch_width) and columns into (n_cols, patch_width)
        .reshape(n_rows, patch_width, n_cols, patch_width)
        # bring the two block indices to the front
        .permute(0, 2, 1, 3)
    )

#
# Plot
#
f, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5))

for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        axs[row_idx, col_idx].imshow(patches[row_idx, col_idx, ...] / 255)

for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])

f.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()
```
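The per-channel loop can probably be dropped: since C is the last axis of the (H, W, C) tensor, one reshape that keeps C as a trailing dimension splits all channels at once. A sketch, assuming an (H, W, C) tensor already cropped to multiples of the patch width; it also flattens to the (N, P, P, C) layout asked for in the question, and cross-checks the result against `torch.Tensor.unfold`, an alternative not used in the answer above. The function names `patchify` and `patchify_unfold` are hypothetical.

```python
import torch

def patchify(img: torch.Tensor, p: int) -> torch.Tensor:
    """Split an (H, W, C) tensor into (N, p, p, C) patches, numbered row-major.

    Assumes H and W are divisible by p (crop first otherwise).
    """
    H, W, C = img.shape
    n_rows, n_cols = H // p, W // p
    patches = (
        img.reshape(n_rows, p, n_cols, p, C)   # split each spatial axis into (block, offset)
        .permute(0, 2, 1, 3, 4)                # -> (n_rows, n_cols, p, p, C)
    )
    return patches.reshape(n_rows * n_cols, p, p, C)

def patchify_unfold(img: torch.Tensor, p: int) -> torch.Tensor:
    """Same result via Tensor.unfold(dimension, size, step)."""
    C = img.shape[2]
    x = img.unfold(0, p, p).unfold(1, p, p)    # (n_rows, n_cols, C, p, p)
    return x.permute(0, 1, 3, 4, 2).reshape(-1, p, p, C)

img = torch.arange(64 * 64 * 3).reshape(64, 64, 3)
a = patchify(img, 16)
print(a.shape)                                   # torch.Size([16, 16, 16, 3])
print(torch.equal(a, patchify_unfold(img, 16)))  # True
print(torch.equal(a[5], img[16:32, 16:32]))      # True: patch 5 is patch row 1, patch col 1
```

The output layout is the (N, P, P, C) shape that the question's `show_patches` indexes into.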