Suppose a multi-task deep learning problem has over one thousand tasks, i.e. more than a thousand label columns, and each task (column) has its own weight. Looping over every task to accumulate the total loss, as in the following snippet, takes a long time.
criterion = nn.MSELoss()
outputs = model(inputs)
loss = torch.tensor(0.0).to(device)
for j, w in enumerate(weights):
    # mask keeping only the labeled molecules for this task
    mask = labels[:, j] >= 0.0
    if len(labels[:, j][mask]) != 0:
        # the total loss is the weighted sum of the per-task losses;
        # this task has labeled samples, so we add its loss
        loss += criterion(outputs[j][mask], labels[:, j][mask].view(-1, 1)) * w
This dataset is quite small: it has 10K rows and 1024 feature columns, and the labels form a 10K × 160 sparse matrix in which each of the 160 columns is one task. The batch size is 32. Below are the shapes of outputs, labels, weights:
len(outputs[0]), len(outputs)
(32, 160)
torch.Size([32, 160])
What I really want to try, though, is a dataset with over 1M rows, 1024 features and over 10K label columns. The labels are sparse, of course.
Thanks for your suggestions and code, Shai. I modified the code slightly as follows, but the loss was the same as with your code.
all_out =, -1).T
all_mask = labels != -100.0
err = (all_out - labels) ** 2 # raw L2
err = all_mask * err # mask only the relevant entries in the err
mask_nums = all_mask.sum(axis=0)
err = err * weights[None, :] # weight each task
err = err / mask_nums[None, :]
err[err != err] = torch.tensor([0.0], requires_grad=True).to(device) # replace nan with 0.0
loss = err.sum()
A new problem is that the loss can't be backpropagated: only the loss of the first batch was computed, and every following batch got a loss of 0.0.
Epoch: [1/20], Step: [1/316], Loss: 4.702103614807129
Epoch: [1/20], Step: [2/316], Loss: 0.0
Epoch: [1/20], Step: [3/316], Loss: 0.0
Epoch: [1/20], Step: [4/316], Loss: 0.0
Epoch: [1/20], Step: [5/316], Loss: 0.0
After the first batch, the loss was 0 and outputs was a 32 × 160 tensor of nan.
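One plausible cause, assuming at least one task column has no labeled sample in a batch: mask_nums is 0 for that column, so the division evaluates 0 / 0 = nan. Overwriting the nan entries afterwards repairs the forward value, but the backward pass still goes through the division and produces nan gradients. After the first optimizer step the parameters would then be nan, which is consistent with the log above. The sketch below reproduces the effect on toy tensors (`x`, `mask` and `counts` are hypothetical stand-ins, not names from the post) and shows one possible workaround, clamping the divisor:

```python
import torch

# Toy stand-ins (hypothetical): two "tasks", the second has no labeled sample.
x = torch.tensor([1.0, 2.0], requires_grad=True)
mask = torch.tensor([1.0, 0.0])      # 1 = labeled, 0 = unlabeled
counts = torch.tensor([1.0, 0.0])    # labeled samples per task

# Pattern from the post: divide first, then overwrite nan with 0.
err = (x ** 2) * mask / counts       # second entry is 0 / 0 = nan
err[err != err] = 0.0                # forward value is finite again
err.sum().backward()
grad_nan_patch = x.grad.clone()      # the gradient still contains nan

# Possible workaround: keep the divisor away from zero so no nan is created.
x.grad = None
safe = (x ** 2) * mask / counts.clamp(min=1.0)
safe.sum().backward()
grad_clamped = x.grad.clone()        # finite gradient
```

With the clamp, a task without labels contributes exactly zero to both the loss and the gradient, so no nan replacement step is needed.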
How is your loss different than:
all_out = torch.cat([o_[:, None] for o_ in outputs], dim=1) # all_out has shape 32x160
all_mask = labels >= 0
err = (all_out - labels) ** 2 # raw L2
err = all_mask * err # mask only the relevant entries in the err
err = err * weights[None, :] # weight each task
err = err.sum()
There might be a slight issue here with the summation: you might need to normalize by the number of 1s in each column of all_mask.
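A minimal sketch of that per-column normalization, checked against the loop from the question on random data. It is an illustration under assumptions, not the code from either post: the helper name `masked_weighted_mse` is hypothetical, each task head is assumed to return a tensor of shape (batch, 1) as the `.view(-1, 1)` in the question suggests, and negative labels are assumed to mark missing entries.

```python
import torch
import torch.nn as nn

def masked_weighted_mse(all_out, labels, weights):
    """Sum over tasks of weight * (mean squared error over labeled entries).

    Assumes, as in the question, that labels < 0 mark missing entries.
    """
    all_mask = labels >= 0                        # (batch, tasks) bool
    # Zero the difference at unlabeled entries so they contribute nothing.
    diff = torch.where(all_mask, all_out - labels, torch.zeros_like(all_out))
    per_task_sum = (diff ** 2).sum(dim=0)         # (tasks,)
    # Labeled samples per task; the clamp avoids 0 / 0 for tasks with none.
    counts = all_mask.sum(dim=0).clamp(min=1)
    return (weights * per_task_sum / counts).sum()

torch.manual_seed(0)
batch, tasks = 8, 5
outputs = [torch.randn(batch, 1) for _ in range(tasks)]  # one (batch, 1) head per task
labels = torch.rand(batch, tasks)
labels[torch.rand(batch, tasks) < 0.5] = -100.0           # mark about half as missing
labels[:, 2] = -100.0                                     # one task with no labels at all
weights = torch.rand(tasks)

all_out = torch.cat(outputs, dim=1)                       # (batch, tasks)
vectorized = masked_weighted_mse(all_out, labels, weights)

# Reference: the per-task loop from the question.
criterion = nn.MSELoss()
looped = torch.tensor(0.0)
for j, w in enumerate(weights):
    mask = labels[:, j] >= 0.0
    if mask.any():
        looped = looped + criterion(outputs[j][mask], labels[:, j][mask].view(-1, 1)) * w
```

Because nn.MSELoss averages over the labeled samples of each task by default, dividing each column sum by its count of labeled entries is what makes the vectorized value agree with the loop.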