
Weighted mean squared error in PyTorch with per-sample weights for regression


I am trying to use a weighted mean squared error loss function for a regression task with an imbalanced dataset. Each example has its own weight, and I pass those weights to the weighted MSE loss. Is there a way to batch the weight tensor with TensorDataset, so that each batch of inputs and targets comes with its matching weights?

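import torch

def weighted_mse_loss(inputs, targets, weights=None):
    # Per-element squared error
    loss = (inputs - targets) ** 2
    if weights is not None:
        # One weight per sample: reshape (batch,) to (batch, 1, ...) so it
        # broadcasts over the remaining dimensions of the loss
        loss = loss * weights.reshape(-1, *([1] * (loss.dim() - 1)))
    return loss.mean()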

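train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
weights = torch.rand(len(train_dataset))  # one weight per training example

for x, y in train_loader:
    optimizer.zero_grad()
    out = model(x)
    # Problem: weights covers the whole dataset, not just the current batch
    loss = weighted_mse_loss(out, y, weights)
    loss.backward()
    optimizer.step()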

Solution

  • If you can get the weights before creating the train dataset:

    train_dataset = TensorDataset(x_train, y_train, weights)
    for x, y, w in train_dataset:
        ...
    

    Otherwise:

    train_dataset = TensorDataset(x_train, y_train)
    for (x, y), w in zip(train_dataset, weights):
        ...
    

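    You can also use a DataLoader, but be careful about shuffling with the second method: once the loader shuffles, the samples no longer come out in the same order as weights, so pairing them by position assigns the wrong weight to each sample.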
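
As a rough sketch of how both cases might look with a shuffled DataLoader: the first loop stores the weights in the TensorDataset, as suggested above. The second loop is not from the answer. It is one possible workaround for weights that only become available later: it stores each sample's index in the dataset and looks the weights up per batch. The data shapes, the Linear model, the SGD optimizer, and the batch size are placeholder assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def weighted_mse_loss(inputs, targets, weights=None):
    # Per-element squared error
    loss = (inputs - targets) ** 2
    if weights is not None:
        # Reshape (batch,) weights to (batch, 1, ...) so each sample's
        # weight broadcasts over the remaining dimensions
        loss = loss * weights.reshape(-1, *([1] * (loss.dim() - 1)))
    return loss.mean()


torch.manual_seed(0)

# Placeholder data (assumed shapes): 64 samples, 3 features, 1 target
x_train = torch.randn(64, 3)
y_train = torch.randn(64, 1)
weights = torch.rand(64)

# Placeholder model and optimizer, for illustration only
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Case 1: weights are known up front, so they live in the dataset and
# are shuffled and batched together with x and y.
train_loader = DataLoader(
    TensorDataset(x_train, y_train, weights), batch_size=16, shuffle=True
)
for x, y, w in train_loader:
    optimizer.zero_grad()
    loss = weighted_mse_loss(model(x), y, w)
    loss.backward()
    optimizer.step()

# Case 2: weights are computed after the dataset exists. Carrying each
# sample's index keeps the lookup correct even when shuffling.
indices = torch.arange(len(x_train))
indexed_loader = DataLoader(
    TensorDataset(x_train, y_train, indices), batch_size=16, shuffle=True
)
for x, y, idx in indexed_loader:
    w = weights[idx]  # weights for exactly the samples in this batch
    optimizer.zero_grad()
    loss = weighted_mse_loss(model(x), y, w)
    loss.backward()
    optimizer.step()
```

Indexing with `weights[idx]` avoids relying on iteration order at all, which is why it tolerates `shuffle=True` where zipping by position does not.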