In the coding example below, I cannot understand why the output tensor grid has a shape of (3, 28, 280). I understand why it is 28 in height and 280 in width, but not the 3. From running plt.imshow() on each of the three 28x280 arrays along axis 0, they appear to be identical copies, since plotting any one of them gives me the image I want.

I also do not understand why I can pass grid as an argument to plt.imshow(), given that it is supposed to take a 2D array, not a 3D one as grid clearly is.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

train_set = torchvision.datasets.FashionMNIST(
    root='./pytorch_obj_classifier/data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

# batch_size=10 so one batch fills a single grid row (nrow=10 below)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

sample = next(iter(train_loader))
image, label = sample
print(image.shape)   # torch.Size([10, 1, 28, 28])

grid = torchvision.utils.make_grid(image, padding=0, nrow=10)
print(grid.shape)    # torch.Size([3, 28, 280])

plt.figure(figsize=(15, 15))
grid = np.transpose(grid, (1, 2, 0))  # (C, H, W) -> (H, W, C)
grid1 = grid[:, :, 0]
grid2 = grid[:, :, 1]
grid3 = grid[:, :, 2]
plt.imshow(grid1, cmap='gray')
plt.imshow(grid2, cmap='gray')
plt.imshow(grid3, cmap='gray')
plt.imshow(grid, cmap='gray')
plt.show()
The FashionMNIST dataset consists of grayscale images. If you look at the implementation of torchvision.utils.make_grid, single-channel images get their channel copied three times:
if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
    tensor = torch.cat((tensor, tensor, tensor), 1)
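You can check this yourself. Here is a minimal sketch with a random single-channel batch standing in for the FashionMNIST images; the three channels of the resulting grid are exact copies of one another:

import torch
import torchvision

batch = torch.rand(2, 1, 28, 28)  # two fake single-channel 28x28 images
grid = torchvision.utils.make_grid(batch, padding=0, nrow=2)

print(grid.shape)                     # torch.Size([3, 28, 56])
print(torch.equal(grid[0], grid[1]))  # True
print(torch.equal(grid[1], grid[2]))  # True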
As for matplotlib.pyplot.imshow, it can take either 2D or 3D inputs. From its documentation:
The image data. Supported array shapes are:
- (M, N): an image with scalar data. The data is visualized using a colormap.
- (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
- (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), i.e. including transparency.
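For example, both of the following calls work; note that a colormap only applies to scalar (M, N) data and is ignored for RGB input, which is why passing the full 3-channel grid still renders correctly (a minimal sketch with random data):

import numpy as np
import matplotlib.pyplot as plt

img2d = np.random.rand(28, 280)         # (M, N): scalar data, drawn with a colormap
img3d = np.stack([img2d] * 3, axis=-1)  # (M, N, 3): interpreted as RGB

plt.imshow(img2d, cmap='gray')
plt.show()
plt.imshow(img3d, cmap='gray')  # cmap has no effect on RGB data
plt.show()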
Generally speaking, we wouldn't refer to dimensions but rather describe tensors by their shape (the size of each of their axes). In PyTorch, images always have three axes and a shape of the form (channel, height, width). This holds even for single-channel images: they are represented as a 3D tensor (1, height, width) rather than a 2D tensor (height, width). This keeps them consistent with cases where you have more than one channel, which is very common (cf. convolutional neural networks).
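As an illustration, here is a minimal sketch (assuming the train_set defined in the question): a single FashionMNIST sample comes out of ToTensor as a (1, 28, 28) tensor, and squeezing out the channel axis gives the 2D array that imshow draws with a colormap:

import matplotlib.pyplot as plt

# train_set is the FashionMNIST dataset from the question's code
image, label = train_set[0]
print(image.shape)  # torch.Size([1, 28, 28]), i.e. (channel, height, width)

plt.imshow(image.squeeze(0), cmap='gray')  # drop the channel axis -> (28, 28)
plt.show()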