In PyTorch DataParallel mode, how do I use a global tensor?

In this example, I want z_proto to be global, i.e. the same full tensor on every GPU. However, in DataParallel mode it is split across the GPUs along with the other inputs. How can I solve this? Thank you.

import torch.nn as nn
import torch.nn.functional as F

# euclidean_dist is assumed to be defined elsewhere; it is not shown in the question.

class SequencePrototypeTokenClassification(nn.Module):
    def __init__(self, seq_model, label_num):
        super().__init__()
        self.seq_model = seq_model
        self.label_num = label_num

    def forward(self, input_ids, token_type_ids, attention_mask, labels, z_proto, n_query, target_inds):
        # Encode the tokens; the last dimension of z is the hidden size.
        z, _ = self.seq_model(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        z_dim = z.size(-1)
        zq = z.squeeze().view(-1, z_dim)
        # z_proto is a forward() argument, so DataParallel splits it across GPUs.
        dists = euclidean_dist(zq, z_proto)
        log_p_y = F.log_softmax(-dists, dim=1).view(-1, self.label_num)
        # target_inds is a forward() argument, not an attribute of the module.
        loss_val = -log_p_y.gather(1, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(1)

        return loss_val, y_hat


  • It turns out that DataParallel splits the tensors passed to forward across the GPUs, but replicates the module's nn.Parameter objects (and registered buffers) to each GPU. So I added a randomly initialized nn.Parameter named z_proto to the module and copied the value of the tensor z_proto into it. The parameter is then replicated in full on all 4 GPUs.
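
A minimal sketch of that approach, not the author's actual code. PrototypeHead, set_prototypes, and the sizes used are hypothetical names and values chosen for illustration, and euclidean_dist is assumed to be a pairwise squared distance, since the question does not show it. The sequence encoder is left out so the example stays self-contained.

```python
import torch
import torch.nn as nn


def euclidean_dist(x, y):
    # Assumed helper: pairwise squared Euclidean distance
    # between the rows of x (n, d) and the rows of y (m, d).
    return (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)


class PrototypeHead(nn.Module):
    """Hypothetical stand-in for the module in the question."""

    def __init__(self, label_num, z_dim):
        super().__init__()
        # Randomly initialized placeholder; the real values are copied in later.
        # requires_grad=False keeps an optimizer from updating the prototypes.
        self.z_proto = nn.Parameter(torch.randn(label_num, z_dim), requires_grad=False)

    @torch.no_grad()
    def set_prototypes(self, z_proto):
        # Copy the externally computed prototypes into the parameter in place.
        self.z_proto.copy_(z_proto)

    def forward(self, zq):
        # zq is split along dim 0 by DataParallel; self.z_proto is replicated whole.
        dists = euclidean_dist(zq, self.z_proto)
        return torch.log_softmax(-dists, dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prototypes = torch.randn(3, 8)  # 3 labels, 8-dimensional embeddings

model = PrototypeHead(label_num=3, z_dim=8).to(device)
model.set_prototypes(prototypes)

# Without any GPU, DataParallel simply calls the wrapped module directly.
parallel_model = nn.DataParallel(model)

queries = torch.randn(10, 8, device=device)
log_p_y = parallel_model(queries)  # shape (10, 3)
```

The prototypes have to be copied in again whenever they change, because each forward call replicates whatever the parameter holds at that moment. If the prototypes should never be trained, registering the tensor with register_buffer may be a cleaner alternative, since buffers are also replicated.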