Tags: python, machine-learning, computer-vision, torch

PyTorch error: ValueError: pic should be 2/3 dimensional. Got 4 dimensions


I am trying to follow this tutorial here. However, when I load my content image and style image and try to use the imshow() function, I get this error:

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

Searching Google, I have not been able to find any remedy for this problem.

Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms 
import torchvision.models as models
import copy
import numpy as np

# This detects if cuda is available for GPU training otherwise will use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Desired size of the output image
imsize = 512 if torch.cuda.is_available() else 256
print(imsize)

# Helper function
def image_loader(image_name, imsize):
    # Scale the imported image and transform it into a torch tensor
    loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])
    image = Image.open(image_name)
    # Fake batch dimension required to fit network's input dimension
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

# Helper function to show the tensor as a PIL image
def imshow(tensor, title=None):
    unloader = transforms.ToPILImage()
    image = tensor.cpu().clone()
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # Pause so that the plots are updated

# Loading of images
image_directory = './images/'
style_img = image_loader(image_directory + "pb.jpg", imsize)
content_img = image_loader(image_directory + "content.jpg", imsize)
assert style_img.size() == content_img.size(), "we need to import style and content images of the same size"

plt.figure()
imshow(style_img, title='style image')

Any suggestions would be really helpful.

Here are the style and content images for reference:

[Style image]

[Content image]


Solution

  • matplotlib.pyplot.imshow expects either a 2D array (grayscale, shape (H, W)) or a 3D array (colored, shape (H, W, color channel)). transforms.ToPILImage, which your imshow helper calls before plotting, has the same 2D/3D requirement and is what actually raises this ValueError.

    You probably still have the batch size as the first dimension of your tensor, because in your code you do:

    # Fake batch dimension required to fit network's input dimension
    image = loader(image).unsqueeze(0)
    

    which adds this extra first dimension, so the tensor has shape (1, C, H, W). If so, remove the batch dimension again before converting the tensor back to an image, either with:

    image = image.squeeze(0)


    or

    image = image[0]
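
For reference, here is a minimal sketch of the imshow helper from the question with the batch dimension removed before the ToPILImage conversion; the variable names (tensor, unloader, title) come from the question's code, and only the squeeze step is new:

import matplotlib.pyplot as plt
import torchvision.transforms as transforms

unloader = transforms.ToPILImage()

def imshow(tensor, title=None):
    # Work on a CPU copy so the original (possibly GPU) tensor is untouched
    image = tensor.cpu().clone()
    # Drop the fake batch dimension added by unsqueeze(0): (1, C, H, W) -> (C, H, W)
    image = image.squeeze(0)
    # ToPILImage now receives a 3D tensor, so the conversion succeeds
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # Pause so that the plots are updated

Called as imshow(style_img, title='style image'), this should display the image without raising the ValueError.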