Tags: python, python-3.x, pytorch, torchvision

Why does torchvision.utils.make_grid() return copies of the wanted grid?


In the code example below, I cannot understand why the output tensor, grid, has a shape of (3, 28, 280). I understand why it is 28 in height and 280 in width, but not the 3. Running plt.imshow() on each of the three 28x280 arrays along axis 0 suggests they are identical copies, since printing any one of them gives me the image I want. I also do not understand why I can pass grid as an argument to plt.imshow(), given that it is supposed to take a 2D array, not a 3D one as grid clearly is.

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

train_set = torchvision.datasets.FashionMNIST(
    root = './pytorch_obj_classifier/data/FashionMNIST',
    train = True,
    download = True,
    transform = transforms.Compose([
            transforms.ToTensor()
    ])
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)  # batches of 10 images
sample = next(iter(train_loader))
image, label = sample
print(image.shape)

grid = torchvision.utils.make_grid(image, padding=0, nrow=10)
print(grid.shape)

plt.figure(figsize=(15, 15))
grid = np.transpose(grid, (1, 2, 0))
grid1 = grid[:, :, 0]
grid2 = grid[:, :, 1]
grid3 = grid[:, :, 2]
plt.imshow(grid1, cmap='gray')
plt.imshow(grid2, cmap='gray')
plt.imshow(grid3, cmap='gray')
plt.imshow(grid, cmap='gray')


Solution

  • The FashionMNIST dataset consists of grayscale images. If you look at the implementation of torchvision.utils.make_grid, single-channel images get their channel copied three times:

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)
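
    You can confirm that the three channels really are exact copies. A minimal sketch, reusing grid from the question (before the np.transpose step):

    print(torch.equal(grid[0], grid[1]))  # True: channels 0 and 1 are identical
    print(torch.equal(grid[1], grid[2]))  # True: channels 1 and 2 are identical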
    

    As for matplotlib.pyplot.imshow, it can take either 2D or 3D inputs (see the sketch after the quoted list):

    The image data. Supported array shapes are:

    • (M, N): an image with scalar data. The data is visualized using a colormap.
    • (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
    • (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), i.e. including transparency.
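
    Since the three channels are identical, a single 2D slice rendered through the gray colormap and the full (M, N, 3) array interpreted as RGB produce the same picture (note that cmap is ignored for RGB data). A minimal sketch, reusing grid after the np.transpose(grid, (1, 2, 0)) step from the question:

    plt.imshow(grid[:, :, 0], cmap='gray')  # (M, N): scalar data, rendered via the colormap
    plt.show()
    plt.imshow(grid)                        # (M, N, 3): interpreted as RGB; cmap is ignored
    plt.show()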

    Generally speaking, we wouldn't refer to dimensions but rather describe tensors by their shape (the size along each of their axes). In PyTorch, images always have three axes, with a shape of the form (channels, height, width). Even a single-channel image is treated as a 3D tensor of shape (1, height, width) rather than a 2D tensor of shape (height, width). This keeps it consistent with the cases where you have more than one channel, which is very common (cf. convolutional neural networks).
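
    For example, a minimal sketch of the shape convention (the tensors here are just illustrative):

    import torch

    gray = torch.rand(1, 28, 28)  # one grayscale image: (channels, height, width)
    batch = gray.unsqueeze(0)     # a batch of one image: (batch, channels, height, width)
    print(gray.shape)             # torch.Size([1, 28, 28])
    print(batch.shape)            # torch.Size([1, 1, 28, 28])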