python-3.x pytorch neural-network mean

Get the mean of the last 4 layers of a deep neural network for a 3D PyTorch tensor object


I am trying to get the mean of the last 4 hidden layers of a BERT model.

Each hidden layer has the shape [batch_size, seq_len, hidden_size]:

outputs[1][-1] = [2, 256, 768]
outputs[1][-2] = [2, 256, 768]
outputs[1][-3] = [2, 256, 768]
outputs[1][-4] = [2, 256, 768]

where 2 is the batch size, 256 the sequence length and 768 the hidden size.

I want to average the four layers element-wise, so the output should have the same shape, [2, 256, 768].
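Written out, the requested operation is an element-wise average over the layer index. Here H^{(\ell)} denotes the hidden-state tensor of layer \ell and L the index of the last layer; this notation is introduced only for illustration:

```latex
\bar{H}_{b,t,d} = \frac{1}{4} \sum_{k=0}^{3} H^{(L-k)}_{b,t,d},
\qquad b \in \{1,2\},\; t \in \{1,\dots,256\},\; d \in \{1,\dots,768\}
```

Only the layer index is averaged away; the batch (b), token (t) and hidden (d) indices are all kept, which is why the result retains the shape [2, 256, 768].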

Here is my code:

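import torch
import torch.nn as nn
from torchcrf import CRF  # assumption: CRF comes from the pytorch-crf package

class BERT_CRF(nn.Module):
    def __init__(self, bert_model, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)
        self.classifier = nn.Linear(768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Concatenates along the hidden dim -> [2, 256, 3072], then averages over
        # dims 0, 1 and 2, which collapses the result to a 0-dim scalar.
        sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1).mean(dim=[0, 1, 2])
        sequence_output = self.dropout(sequence_output)
        emission = self.classifier(sequence_output)
        # ... (rest of forward not shown in the question)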

I tried concatenating the four layers with torch.cat(..., -1) and then averaging with .mean(dim=[0,1,2]), but it does not give me the expected result.
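A quick shape check with random tensors shows why this attempt fails. The tuple `hidden_states` below is an assumed stand-in for `outputs[1]` (13 entries, mimicking BERT-base's embedding output plus 12 layers). Concatenating along the last dimension places the layers side by side, widening the hidden size to 4 × 768 = 3072, and averaging over dims 0, 1 and 2 then reduces everything to a single scalar.

```python
import torch

# Stand-in for outputs[1]: a tuple of per-layer hidden states, each [batch, seq_len, hidden].
# 13 entries mimics BERT-base (embeddings + 12 layers); this layout is an assumption.
hidden_states = tuple(torch.randn(2, 256, 768) for _ in range(13))

# Concatenating along the last dim joins the layers side by side instead of stacking them.
concatenated = torch.cat(
    (hidden_states[-1], hidden_states[-2], hidden_states[-3], hidden_states[-4]), -1
)
print(concatenated.shape)  # torch.Size([2, 256, 3072])

# Averaging over every dimension reduces the tensor to a 0-dim scalar.
collapsed = concatenated.mean(dim=[0, 1, 2])
print(collapsed.shape)  # torch.Size([])
```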


Solution

  • You are looking to stack the four tensors and average over the newly created dimension. Since you want the last four elements of outputs[1], you can do:

    >>> torch.stack(outputs[1][-4:]).mean(0)
    

    This would return the average of outputs[1][-1], outputs[1][-2], outputs[1][-3], and outputs[1][-4]...
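Here is a self-contained sketch of the stacking approach, again using random tensors in place of `outputs[1]`. The names `hidden_states` and `mean_last_k` are hypothetical, and the helper assumes the hidden states arrive as a tuple or list of equally shaped tensors. `torch.stack` adds a new leading dimension of size 4, and `.mean(0)` averages over it, leaving the original [2, 256, 768] shape.

```python
import torch

def mean_last_k(hidden_states, k=4):
    """Element-wise mean of the last k hidden-state tensors (hypothetical helper)."""
    # hidden_states[-k:] is a tuple/list of k tensors, each [batch, seq_len, hidden].
    stacked = torch.stack(hidden_states[-k:], dim=0)  # [k, batch, seq_len, hidden]
    return stacked.mean(dim=0)  # [batch, seq_len, hidden]

# Stand-in for outputs[1]: 13 layers of shape [2, 256, 768] (assumed BERT-base layout).
torch.manual_seed(0)
hidden_states = tuple(torch.randn(2, 256, 768) for _ in range(13))

averaged = mean_last_k(hidden_states)
print(averaged.shape)  # torch.Size([2, 256, 768])

# Same result (up to float rounding) as summing the four layers and dividing by 4.
manual = (hidden_states[-1] + hidden_states[-2] + hidden_states[-3] + hidden_states[-4]) / 4
print(torch.allclose(averaged, manual, atol=1e-5))  # True
```

Inside the question's `forward`, this would replace the `torch.cat(...)` line, e.g. `sequence_output = torch.stack(outputs[1][-4:]).mean(0)`, so `self.classifier` (an `nn.Linear(768, num_labels)`) still receives 768 features per token.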