Tags: pytorch, torch, torchvision

Model with CrossEntropyLoss criterion doesn't apply softmax (PyTorch)


I am using nn.CrossEntropyLoss() as my criterion in a model that I am developing. The problem I am having is that the model outputs a tensor of size (batchsize, #classes) when it is supposed to output a (batchsize) tensor.

Isn't CrossEntropyLoss supposed to apply LogSoftmax?
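
A quick way to verify what CrossEntropyLoss computes, with made-up logits (a minimal sketch, stock PyTorch only): it matches LogSoftmax followed by NLLLoss.

import torch
import torch.nn as nn

logits = torch.randn(4, 3)               # made-up (batchsize, #classes) scores
target = torch.tensor([2, 1, 0, 2])      # Long class indices

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(torch.allclose(ce, nll))           # True: CE = LogSoftmax + NLLLoss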

Here's my Dataset:

import os

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class DatasetPlus(Dataset):
    def __init__(self, root_img, root_data, width, height, transform=None):
        self.root_img = root_img
        self.root_data = root_data
        self.width = width
        self.height = height
        self.transform = transform
        # labels are stored in a csv file
        self.labels = pd.read_csv(self.root_data)
        self.imgs = [image for image in sorted(
            os.listdir(self.root_img)) if image.endswith('.jpg')]
        self.len = len(self.imgs)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.root_img, img_name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img = cv2.resize(img, (self.width, self.height),
                         interpolation=cv2.INTER_AREA)
        img = img / 255.0  # scale pixel values to [0, 1]

        if self.transform is not None:
            img = self.transform(img)

        img_id = int(img_name[6:-4])
        # look up the label of the row whose ID matches this image
        label = self.labels.loc[self.labels['ID'] == img_id, 'Label'].to_numpy()[0]

        label = torch.tensor(label, dtype=torch.float32)

        return img, label

Here is my model:

class Net(nn.Module):
    def __init__(self, h, w):
        super().__init__()
        # feature-map size after two (conv k=5, pool /2) stages
        nw = (((w - 4) // 2) - 4) // 2
        nh = (((h - 4) // 2) - 4) // 2
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * nh * nw, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # raw logits; no softmax here
        return x
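
For reference, a quick shape check with random input (a minimal sketch; torch.randn stands in for real images) reproduces the (batchsize, #classes) output described above:

net = Net(224, 224)
x = torch.randn(4, 3, 224, 224)  # batch of 4 RGB 224x224 images
print(net(x).shape)              # torch.Size([4, 3]) -- (batchsize, #classes)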

Here's my training code:

model = Net(224, 224)

trainloader = DataLoader(ds, batch_size=4, shuffle=True)  # ds is a DatasetPlus instance

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_model(epochs):
    for epoch in range(epochs): 
        losses = 0.0 
        for i, data in enumerate(trainloader, 0):
            optimizer.zero_grad()
            img, label = data
            yhat = model(img)
            loss = criterion(yhat, label)
            loss.backward()
            optimizer.step()
            losses += loss.item()
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {losses:.3f}')
            losses = 0.0

train_model(5)

I have explained the problem above, but here's the error anyway:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[9], line 1
----> 1 train_model(5)

Cell In[8], line 13, in train_model(epochs)
     11 print(yhat.size())
     12 print(label.size())
---> 13 loss = criterion(yhat, label)
     14 loss.backward()
     15 optimizer.step()

File c:\Users\Yasamin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Yasamin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\loss.py:720, in BCEWithLogitsLoss.forward(self, input, target)
    719 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 720     return F.binary_cross_entropy_with_logits(input, target,
    721                                               self.weight,
    722                                               pos_weight=self.pos_weight,
    723                                               reduction=self.reduction)

File c:\Users\Yasamin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\functional.py:3160, in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   3157     reduction_enum = _Reduction.get_enum(reduction)
   3159 if not (target.size() == input.size()):
-> 3160     raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   3162 return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 3]))

And finally, these are the outputs and the labels that raise this error:

yhat=
tensor([[ 0.0097,  0.0184, -0.1236],
        [ 0.0020,  0.0135, -0.1324],
        [ 0.0095,  0.0136, -0.1261],
        [ 0.0027,  0.0176, -0.1285]], grad_fn=<AddmmBackward0>)
torch.Size([4, 3])

label=
tensor([2., 1., 0., 2.])
torch.Size([4])

Solution

  • From what I found out, CrossEntropyLoss accepts its target in two different forms.

    If you pass it Long labels, it treats them as integer class indices, so the (batchsize) shape is correct.

    But if you pass CrossEntropyLoss labels of type Float (as I do in my code), it treats them as probabilistic ("soft") labels and expects them to have shape (batchsize, #classes), i.e. the same shape as yhat; that mismatch is what raises the error above.

    So to fix the error, convert label to Long before passing it to CrossEntropyLoss (or create the tensor with dtype=torch.int64 in the first place); see the sketch after this list.

    It is also worth noting that the labels must range from 0 to #classes - 1 for CrossEntropyLoss to operate correctly.
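
A minimal sketch of both target modes and of the fix (made-up logits; Float "soft" targets require PyTorch 1.10 or newer):

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
yhat = torch.randn(4, 3)                  # logits, shape (batchsize, #classes)
label = torch.tensor([2., 1., 0., 2.])    # Float labels, as in the question

# The fix: cast to Long so the labels are read as class indices in 0..#classes-1.
loss = criterion(yhat, label.long())      # target shape (4,) -- works

# For comparison, Float targets are read as soft labels and must match yhat's
# shape, e.g. one-hot rows of shape (4, 3):
soft = F.one_hot(label.long(), num_classes=3).float()
loss_soft = criterion(yhat, soft)         # also works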