Search code examples
Tags: python, machine-learning, pytorch, neural-network, cross-entropy

CrossEntropyLoss with class weights raises RuntimeError: expected scalar type Float but found Long


I am using a feedforward neural network for a classification task with 4 classes. The classes are imbalanced, so I want to pass per-class weights to CrossEntropyLoss (via its `weight` argument), as suggested in the linked reference (link not preserved).

Here is my neural network:

class FeedforwardNeuralNetModel(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The output is the raw result of the final linear layer (no softmax), so
    it can be fed directly to ``nn.CrossEntropyLoss`` as logits.

    Args:
        input_dim: number of features in each input vector.
        hidden_dim: width of the hidden layer.
        output_dim: number of output scores (one per class).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Construction order (fc1, relu, fc2) determines the order in which
        # parameters are initialised from the RNG and their state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to scores of shape (..., output_dim)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

And here is how I am using it:

# Training hyper-parameters: SGD step size and mini-batch size.
# NOTE(review): batch_size is not used in the visible code; presumably it
# configures train_dataloader elsewhere — confirm.
learning_rate, batch_size = 0.1, 64

def train_loop(dataloader, model, loss_fn, optimizer, log_every=100):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            must give the total number of samples (used for progress output).
        model: module mapping a batch ``X`` to class scores.
        loss_fn: criterion called as ``loss_fn(pred, target)`` where
            ``target`` is ``y`` cast to int64 (class indices).
        optimizer: optimizer over ``model``'s parameters.
        log_every: print the current loss every ``log_every`` batches
            (batch 0 is always reported). Must be a positive integer.
    """
    size = len(dataloader.dataset)
    # Switch layers such as Dropout/BatchNorm to training behaviour, in case
    # the model was left in eval mode by an earlier evaluation pass.
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Class-index targets must be integer (Long); y may arrive as float,
        # so convert once and reuse.
        target = y.long()

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count actual samples processed so a short final batch is reported
        # correctly (``(batch + 1) * len(X)`` under-counts in that case).
        seen += len(X)
        if batch % log_every == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

ffn_model = FeedforwardNeuralNetModel(input_dim=32, hidden_dim=128, output_dim=4)
loss_fn = nn.CrossEntropyLoss()
# Use the module-level learning_rate hyper-parameter rather than repeating
# the literal, so changing it actually affects training.
optimizer = torch.optim.SGD(ffn_model.parameters(), lr=learning_rate)

num_epochs = 10
for t in range(num_epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, ffn_model, loss_fn, optimizer)
print("Done!")

The code above works fine. However, when I pass class weights to CrossEntropyLoss as follows:

# Integer literals make torch.tensor infer an int64 (Long) dtype, but
# cross_entropy expects floating-point class weights — this is what triggers
# the "expected scalar type Float but found Long" error shown below.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1, 1000, 1000, 1000]))

It gives the following error:

RuntimeError                              Traceback (most recent call last)
Cell In[91], line 13
     11 for t in range(num_epochs):
     12     print(f"Epoch {t+1}\n-------------------------------")
---> 13     train_loop(train_dataloader, ffn_model, loss_fn, optimizer)
     14 print("Done!")

Cell In[90], line 10, in train_loop(dataloader, model, loss_fn, optimizer)
      8 pred = model(X)
      9 print(pred, y.long())
---> 10 loss = loss_fn(pred, y.long())
     11 print(loss)
     13 # Backpropagation

File ~/miniconda3/envs/gnn/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/gnn/lib/python3.9/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File ~/miniconda3/envs/gnn/lib/python3.9/site-packages/torch/nn/functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

RuntimeError: expected scalar type Float but found Long

Solution

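  • You need to pass the weights as a floating-point tensor when creating your loss function. With integer literals, `torch.tensor([1, 1000, 1000, 1000])` infers an int64 (Long) dtype, which `cross_entropy` rejects because it expects float weights: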

    import torch
    from torch import nn

    torch.manual_seed(0)  # reproducible demo data

    # Class weights must be a floating-point tensor: integer literals such as
    # [1, 1000, 1000, 1000] would produce an int64 (Long) tensor, which
    # cross_entropy rejects with "expected scalar type Float but found Long".
    loss_fn = nn.CrossEntropyLoss(
        weight=torch.tensor([1.0, 1000.0, 1000.0, 1000.0], dtype=torch.float32)
    )

    num_classes = 4
    inputs = torch.randn(100, num_classes)  # unnormalised scores (logits)
    # randint's upper bound is exclusive, so pass num_classes (not 3) to draw
    # labels from every class 0..3, including class 3.
    labels = torch.randint(0, num_classes, size=(100,))
    loss = loss_fn(inputs, labels)
    print(loss)