
Cross-entropy loss with varying number of classes


Is there a standard/efficient way in PyTorch to handle cross-entropy loss for a classification problem where the number of classes depends on the sample?

Example: In a batch of size 3, I have:

logits1 = [0.2, 0.2, 0.6],      labels1 = [0, 0, 1]
logits2 = [0.4, 0.1, 0.1, 0.4], labels2 = [1, 0, 0, 0]
logits3 = [0.2, 0.8],           labels3 = [1, 0]

I am looking for the right way to compute cross_entropy_loss(logits, labels) on this batch.


Solution

  • Cross entropy loss is used when a single output class is being predicted. When you say the number of classes depends on the sample, I assume you mean a situation where the number of logits is different for each sample, but we are still in a cross entropy setting where each sample has exactly one correct class.

    In this case you can simply pad the samples with -inf, which will be ignored in the cross entropy loss calculation.

    import torch
    import torch.nn as nn

    # start with our sequences of logits (variable length per sample)
    sequences = [
        [0.2, 0.2, 0.6],
        [0.4, 0.1, 0.1, 0.4],
        [0.2, 0.8]
    ]
    sequences = [torch.tensor(i) for i in sequences]
    
    # represent labels as class int values
    # this is required for pytorch's crossentropyloss
    labels = torch.tensor([2, 0, 0]).long()
    
    # pack sequences into a square batch
    # fill padding values with `-inf`
    padding_value = float('-inf')
    sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    
    # create `CrossEntropyLoss` with `reduction='none'`
    # this makes the loss return a value for each input (i.e. no averaging)
    # so we can compare values
    loss = nn.CrossEntropyLoss(reduction='none')
    
    # compute loss on individual sequences without padding
    # (recent PyTorch versions accept an unbatched (C,) input with a scalar target)
    l1 = torch.stack([loss(sequences[i], labels[i]) for i in range(labels.shape[0])])
    
    # compute loss on padded sequences
    l2 = loss(sequences_padded, labels)
    
    # check values match
    assert torch.allclose(l1, l2)
    
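    If you want a single scalar loss for actual training rather than the per-sample comparison above, a minimal sketch (reusing `sequences_padded` and `labels` from the snippet) is to keep the default mean reduction and call the loss on the padded batch directly:

    # default reduction='mean' averages the per-sample losses into one scalar
    criterion = nn.CrossEntropyLoss()
    batch_loss = criterion(sequences_padded, labels)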

    The -inf padding works because cross entropy exponentiates every logit as part of the softmax, and exp(-inf) evaluates to 0. Because of this, the padding positions contribute nothing to the normalization and have no impact on the output loss.
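
    For instance, a quick standalone check comparing the softmax of the valid logits against the softmax of the same logits padded with -inf shows the padded positions getting exactly zero probability:

    import torch

    valid = torch.tensor([0.2, 0.8])
    padded = torch.tensor([0.2, 0.8, float('-inf'), float('-inf')])

    # both give roughly [0.3543, 0.6457] for the real classes;
    # the padded entries come out as exactly 0
    print(torch.softmax(valid, dim=0))
    print(torch.softmax(padded, dim=0))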