I am trying to use a weighted mean squared error loss for a regression task with an imbalanced dataset. Each example has its own weight, and I pass these weights to a weighted MSE loss function. Is there a way to have `TensorDataset` return the matching weights together with each batch of inputs and targets?
```python
import torch

def weighted_mse_loss(inputs, targets, weights=None):
    loss = (inputs - targets) ** 2
    if weights is not None:
        loss *= weights.expand_as(loss)
    loss = torch.mean(loss)
    return loss
```
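```python
train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
weights = torch.rand(len(train_dataset))  # one weight per example in the whole dataset

for x, y in train_loader:
    optimizer.zero_grad()
    out = model(x)
    # problem: `weights` covers the whole dataset, not just the current batch
    loss = weighted_mse_loss(y, out, weights)
    loss.backward()
    optimizer.step()
```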
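A side note on shapes: if the model output is `(batch, 1)` and the batch weights are `(batch,)`, `weights.expand_as(loss)` raises an error, because expansion aligns trailing dimensions. A minimal sketch of a per-sample variant, assuming `weights` holds exactly one value per example in the current batch, reshapes the weights so they broadcast over the remaining dimensions instead:

```python
import torch


def weighted_mse_loss(inputs, targets, weights=None):
    """Mean of squared errors, optionally scaled by one weight per example.

    Assumes inputs and targets share shape (batch, ...) and that weights,
    if given, has shape (batch,).
    """
    loss = (inputs - targets) ** 2
    if weights is not None:
        # Reshape (batch,) -> (batch, 1, ..., 1) so it broadcasts over
        # every non-batch dimension of `loss`.
        loss = loss * weights.view(-1, *([1] * (loss.dim() - 1)))
    return loss.mean()


pred = torch.tensor([[1.0], [2.0], [3.0]])
target = torch.tensor([[1.0], [0.0], [1.0]])
w = torch.tensor([1.0, 0.5, 2.0])
# squared errors 0, 4, 4 -> weighted 0, 2, 8 -> mean 10/3
print(weighted_mse_loss(pred, target, w))
```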
If you can get the weights before creating the training dataset, pass them to `TensorDataset` as a third tensor:

```python
train_dataset = TensorDataset(x_train, y_train, weights)
for x, y, w in train_dataset:
    ...
```
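The first approach also works with a shuffled `DataLoader`, because `TensorDataset` indexes all of its tensors with the same index. A minimal sketch, where the toy data, the `Linear` model, and the SGD settings are placeholders chosen only for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 100 examples, 3 features, 1 regression target.
x_train = torch.randn(100, 3)
y_train = torch.randn(100, 1)
weights = torch.rand(100)  # one weight per example (placeholder values)

# All tensors share dim 0, so each index yields an aligned (x, y, w) triple.
train_dataset = TensorDataset(x_train, y_train, weights)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for x, y, w in train_loader:
    optimizer.zero_grad()
    out = model(x)  # shape (batch, 1)
    # Per-example squared error scaled by that example's weight, then averaged.
    loss = ((out - y) ** 2 * w.unsqueeze(1)).mean()
    loss.backward()
    optimizer.step()
```

Because `w` arrives already batched and in the same shuffled order as `x` and `y`, no extra bookkeeping is needed.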
Otherwise, zip the dataset with the weights:

```python
train_dataset = TensorDataset(x_train, y_train)
for (x, y), w in zip(train_dataset, weights):
    ...
```
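You can also use a `DataLoader`, but be careful with shuffling in the second method: `weights` stays in its original order, so once the examples are shuffled they no longer line up with their weights.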
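When the dataset is already a `TensorDataset`, one way to avoid the alignment problem is to rebuild it from its `.tensors` attribute with the weights appended, which turns the second case into the first. A minimal sketch with placeholder data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data: the dataset exists before the weights are known.
x_train = torch.randn(50, 4)
y_train = torch.randn(50, 1)
train_dataset = TensorDataset(x_train, y_train)

# Weights computed afterwards (placeholder values), one per example.
weights = torch.rand(len(train_dataset))

# TensorDataset stores its tensors in `.tensors`; append the weights as a third one.
weighted_dataset = TensorDataset(*train_dataset.tensors, weights)

# Shuffling is now safe: every tensor is indexed with the same index.
train_loader = DataLoader(weighted_dataset, batch_size=10, shuffle=True)

for x, y, w in train_loader:
    ...  # compute the weighted loss as before
```

Note that zipping a `DataLoader` directly with `weights` would pair each *batch* with a single weight, so rebuilding the dataset like this is usually the simpler option.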