python, matplotlib, pytorch

How do I display a single image in PyTorch?


How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:

TypeError: Invalid dimensions for image data


Solution

  • Given a Tensor representing the image, use .permute() to move the channel dimension last before passing it to matplotlib, which expects image data shaped (height, width, channels):

    plt.imshow(tensor_image.permute(1, 2, 0))
    

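    Note: permute() does not copy data; it returns a view over the same underlying storage. torch.from_numpy() likewise shares memory with the source NumPy array instead of copying it.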
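
The steps above can be sketched end to end as follows. This is a minimal illustration rather than part of the original answer: the helper name to_hwc and the random stand-in image are assumptions, and the detach()/cpu() calls are added only so the same helper also works for tensors that track gradients or live on a GPU.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch


def to_hwc(tensor_image: torch.Tensor) -> torch.Tensor:
    """Reorder a (C, H, W) image tensor to (H, W, C) for matplotlib.

    Hypothetical helper, named here for illustration only.
    """
    if tensor_image.ndim != 3:
        raise ValueError(
            f"expected a 3-D (C, H, W) tensor, got shape {tuple(tensor_image.shape)}"
        )
    # detach() drops autograd tracking and cpu() moves GPU tensors to host
    # memory; neither copies data for a plain CPU tensor.
    return tensor_image.detach().cpu().permute(1, 2, 0)


# Stand-in data: random floats in [0, 1), the range imshow expects for float RGB.
tensor_image = torch.rand(3, 224, 224)
hwc_image = to_hwc(tensor_image)

print(tuple(tensor_image.shape))  # (3, 224, 224)
print(tuple(hwc_image.shape))     # (224, 224, 3)

# The permuted tensor is a view: it points at the same memory as the original.
print(hwc_image.data_ptr() == tensor_image.data_ptr())  # True

# torch.from_numpy() also shares memory: a write to the array shows up in the tensor.
array = np.zeros((2, 2), dtype=np.float32)
shared = torch.from_numpy(array)
array[0, 0] = 1.0
print(shared[0, 0].item())  # 1.0

if __name__ == "__main__":
    plt.imshow(hwc_image)
    plt.axis("off")
    plt.show()
```

Because the permuted tensor is only a view, an in-place change to tensor_image is also visible through hwc_image; call .clone() first if an independent copy is needed.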