Tags: python, pytorch, conv-neural-network, loss-function

Custom function giving error: too many indices for tensor of dimension 1


I am calculating the loss for a multiclass (7 classes) classification program using PyTorch.

import numpy as np
import torch
import torch.nn as nn

class AFL(nn.Module):
   
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AFL, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        #y_pred=y_pred.size()[1]
        print(y_pred.shape) #[32,7]
        print(y_true.shape) #[32]
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = np.empty(y_pred.shape)
        for i in range(len(y_pred)):
            for j in range(len(y_pred[i])):
                cross_entropy[i][j] = -y_true * torch.log(y_pred[i][j])
        #cross_entropy = -y_true * torch.log(y_pred[0][0]) #here I want to calculate cross_entropy for each class
        
        # Calculate losses separately for each class, only suppressing background class
        back_ce = torch.pow(1 - y_pred[:,0], self.gamma) * cross_entropy[:,0]
        back_ce =  (1 - self.delta) * back_ce

        fore_ce = cross_entropy[:,1,:,:]
        fore_ce = self.delta * fore_ce

        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1))
        return loss

I want to calculate back_ce for each class separately, but I am getting this error:

 back_ce = torch.pow(1 - y_pred[:,0], self.gamma) * cross_entropy[:,0]
IndexError: too many indices for tensor of dimension 1

Can anyone please tell me where I am going wrong? The sizes of y_pred and y_true are shown in the code comments above.


Solution

  • The IndexError comes from the shape of cross_entropy: y_true is a 1-D tensor of integer labels with shape (batch_size,), so an expression like -y_true * torch.log(y_pred[0][0]) is also 1-D, and the two-index subscript cross_entropy[:, 0] then fails (a short reproduction follows below). The fix is to one-hot encode the labels first, so cross_entropy has shape (batch_size, n_class). Here is the AFL for multi-class with multiple common and rare classes.
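
    A minimal reproduction of the failure, assuming the shapes printed in the question (labels of shape [32], predictions of shape [32, 7]):

    import torch

    y_true = torch.randint(0, 7, (32,))    # integer labels, shape (32,)
    y_pred = torch.rand(32, 7)             # probabilities, shape (32, 7)

    cross_entropy = -y_true * torch.log(y_pred[0][0])  # log of a scalar, so the product stays 1-D: (32,)
    cross_entropy[:, 0]  # IndexError: too many indices for tensor of dimension 1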

    import torch
    import torch.nn as nn

    class AsymmetricFocalLoss(nn.Module):
        """For imbalanced datasets.

        Parameters
        ----------
        delta : float, optional
            controls the weight given to false positives and false negatives, by default 0.7
        gamma : float, optional
            Focal Tversky loss' focal parameter, controls the degree of down-weighting of easy examples, by default 2.0
        epsilon : float, optional
            clip value used to prevent log(0) / division-by-zero errors
        common : list, required
            a list of common class indices
        rare : list, required
            a list of rare class indices
        """
        def __init__(self, common, rare, delta=0.7, gamma=2., epsilon=1e-07):
            super(AsymmetricFocalLoss, self).__init__()
            self.delta = delta
            self.gamma = gamma
            self.epsilon = epsilon
            self.common = common
            self.rare = rare

        def forward(self, y_pred, y_labels):
            # y_pred contains probabilities, shape (batch_size, n_class)
            # y_labels contains integer class labels, shape (batch_size,)

            # convert integer labels to one-hot
            y_true = torch.zeros_like(y_pred)
            for i, j in enumerate(y_labels):
                y_true[i, j] = 1

            # clamp, then compute the per-class cross-entropy in one vectorised step
            y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
            cross_entropy = -y_true * torch.log(y_pred)   # shape (batch_size, n_class)

            # Calculate losses separately for each class
            all_ce = []

            # common classes: down-weight easy examples with the focal term
            for c in self.common:
                back_ce = (1 - self.delta) * (torch.pow(1 - y_pred[:, c], self.gamma) * cross_entropy[:, c])
                all_ce.append(back_ce)

            # rare classes: keep the full cross-entropy, scaled by delta
            for r in self.rare:
                fore_ce = self.delta * cross_entropy[:, r]
                all_ce.append(fore_ce)

            loss_stack = torch.stack(all_ce, dim=-1)   # (batch_size, n_class)
            loss_sum = torch.sum(loss_stack, dim=-1)   # (batch_size,)
            loss = torch.mean(loss_sum)                # scalar

            return loss
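
    As an aside, the Python-level one-hot loop can be replaced by torch.nn.functional.one_hot, which performs the same conversion in a single vectorised call (a minimal sketch, assuming y_labels holds integer class indices):

    import torch.nn.functional as F

    # equivalent to the loop above: (batch_size,) integer labels -> (batch_size, n_class) one-hot
    y_true = F.one_hot(y_labels, num_classes=y_pred.shape[1]).to(y_pred.dtype)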
    

    To use this:

    batch_size = 5
    n_class = 7

    y_pred = torch.softmax(torch.rand((batch_size, n_class)), dim=-1)
    y_labels = torch.randint(0, n_class, size=(batch_size,))
    print(f'{y_pred=}\n{y_labels=}')

    lossF = AsymmetricFocalLoss(common=[0, 2, 4, 6], rare=[1, 3, 5])
    loss = lossF(y_pred, y_labels)

    print(f'{loss=}')
    

    output:

    """
    y_pred=tensor([[0.1955, 0.1455, 0.0976, 0.1869, 0.1043, 0.1173, 0.1529],
            [0.1613, 0.1635, 0.1121, 0.1290, 0.1571, 0.0993, 0.1777],
            [0.0978, 0.1340, 0.1025, 0.1993, 0.2197, 0.1041, 0.1425],
            [0.1371, 0.1113, 0.1771, 0.1560, 0.0897, 0.1554, 0.1734],
            [0.1960, 0.1890, 0.1403, 0.1076, 0.1714, 0.1079, 0.0878]])
    y_labels=tensor([0, 3, 2, 5, 3])
    loss=tensor(1.0328)
    """