How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:
TypeError: Invalid dimensions for image data
Given a tensor representing the image, use .permute() to move the channel dimension to the end before passing it to matplotlib, which expects RGB data in (height, width, channels) order:
plt.imshow(tensor_image.permute(1, 2, 0))
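Here is a minimal self-contained sketch. It assumes a CPU tensor with float values in [0, 1], and the random tensor is a hypothetical stand-in for a real image. matplotlib expects RGB float data in [0, 1] or integers in [0, 255] and clips anything outside that range. A tensor that lives on a GPU or requires grad would need .detach().cpu() before plotting.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove for an interactive window
import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for a real image: channels first (C, H, W), floats in [0, 1)
tensor_image = torch.rand(3, 224, 224)

# Reorder (C, H, W) -> (H, W, C), the layout imshow expects for RGB data.
# Assumption: the tensor is on the CPU and does not require grad;
# otherwise use tensor_image.detach().cpu().permute(1, 2, 0).
hwc = tensor_image.permute(1, 2, 0)

img = plt.imshow(hwc)
plt.axis("off")
# plt.show()  # in an interactive session, this opens the figure window
```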
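Note: permute() does not copy or allocate memory; it returns a view of the same data with the strides reordered. from_numpy() doesn't copy either: the resulting tensor shares memory with the NumPy array.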
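The sketch below checks both memory-sharing claims directly. The small zero-filled tensor and array are arbitrary illustrations.

```python
import numpy as np
import torch

t = torch.zeros(3, 2, 2)
p = t.permute(1, 2, 0)  # a view: same storage, strides reordered

# p[i, j, k] maps to t[k, i, j], so this write lands in t[1, 0, 0]
p[0, 0, 1] = 5.0
print(t[1, 0, 0].item())             # 5.0
print(p.data_ptr() == t.data_ptr())  # True: no new allocation
print(p.is_contiguous())             # False: only the strides changed

# from_numpy() wraps the array's buffer instead of copying it
arr = np.zeros((2, 2, 3), dtype=np.float32)
from_arr = torch.from_numpy(arr)
arr[0, 0, 0] = 7.0
print(from_arr[0, 0, 0].item())      # 7.0
```

Because the permuted view is non-contiguous, calling .contiguous() on it does allocate a copy. imshow accepts the view as-is, so that step is not needed for plotting.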