Tags: pytorch, resnet, bert-language-model, torchvision

Bert + Resnet joint learning, pytorch model is empty after instantiation


I'm writing a simple joint model with two branches: one is a ResNet-50 and the other is BERT. I concatenate the two outputs and pass the result to a simple linear layer with 2 output neurons.
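To make the tensor shapes concrete, here is a minimal sketch of the fusion step on its own. It assumes a batch size of 4 and uses random tensors as stand-ins for the two branch outputs (512 features from the modified ResNet-50 head, 768 from BERT-base), so no pretrained weights are needed.

```python
import torch
from torch import nn

batch_size = 4  # hypothetical batch size, for illustration only

# Stand-ins for the two branch outputs (random values, not real features)
res_emb = torch.randn(batch_size, 512)   # ResNet-50 with fc replaced by a 512-unit layer
bert_emb = torch.randn(batch_size, 768)  # BERT-base pooled output (hidden size 768)

# torch.cat takes a sequence of tensors; dim=1 joins along the feature axis
combined = torch.cat((res_emb, bert_emb), dim=1)
print(combined.shape)  # torch.Size([4, 1280])

# Final classification layer: 512 + 768 inputs, 2 output neurons
classification = nn.Linear(512 + 768, 2)
logits = classification(combined)
print(logits.shape)  # torch.Size([4, 2])
```

Concatenating along dim=1 keeps one row per sample, which is why the linear layer's input size has to be the sum of the two embedding sizes.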

I implemented the following model:

import torch
from torch import nn
import torchvision.models as models
from collections import OrderedDict
from transformers import BertModel

class BertResNet(nn.Module):
    def __init__(self):
        super(BertResNet, self).__init__()
        # resnet
        resnet50 = models.resnet50(pretrained=True)
        n_inputs = resnet50.fc.in_features
        # compressed embedding space
        classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ]))

        resnet50.fc = classifier # 512 out resnet 


        bert = BertModel.from_pretrained('bert-base-uncased')

        # final classification layer

        classification = nn.Linear(512 + 768, 2)
        #print(resnet50)
        #print(bert)

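    def forward(self, img, text):
        res_emb = self.resnet50(img)               # (batch, 512)
        bert_emb = self.bert(text).pooler_output   # (batch, 768); text holds token ids

        # torch.cat expects a sequence (tuple/list) of tensors
        combined = torch.cat((res_emb, bert_emb), dim=1)
        out = self.classification(combined)
        return out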

But when I instantiate, I get an empty model:

bert_resnet = BertResNet()

print(bert_resnet)

Out: BertResNet()

list(bert_resnet.parameters()) also returns []


Solution

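  • You never assigned the models to attributes of the BertResNet instance. They are only local variables inside the __init__ method, so they are discarded as soon as __init__ returns. nn.Module registers a submodule (and, with it, its parameters) only when it is assigned as an attribute of self: nn.Module.__setattr__ intercepts that assignment and records the submodule, and that record is what print(model) and model.parameters() report. Assign them to self: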

    def __init__(self):
        super(BertResNet, self).__init__()
        # resnet
        self.resnet50 = models.resnet50(pretrained=True)
        n_inputs = self.resnet50.fc.in_features
        # compressed embedding space; assigning it to self.resnet50.fc
        # registers it as a child of self.resnet50, so a separate
        # self.classifier attribute is not needed
        classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ]))
        self.resnet50.fc = classifier  # 512 out resnet

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # final classification layer
        self.classification = nn.Linear(512 + 768, 2)
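The same behaviour can be reproduced without BERT or ResNet. In this minimal sketch the class names LocalOnly and Registered are hypothetical: one binds an nn.Linear to a local name (the original bug), the other assigns it to self (the fix).

```python
import torch
from torch import nn


class LocalOnly(nn.Module):
    """Mirrors the bug: the layer is bound to a local name only."""

    def __init__(self):
        super().__init__()
        layer = nn.Linear(4, 2)  # discarded when __init__ returns


class Registered(nn.Module):
    """Mirrors the fix: the layer is assigned as an attribute of self."""

    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ sees an nn.Module value and records it
        # as a child, so its parameters become part of this module
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


local_only = LocalOnly()
registered = Registered()

print(local_only)                           # LocalOnly()
print(list(local_only.named_children()))    # []
print([name for name, _ in registered.named_parameters()])
# ['layer.weight', 'layer.bias']
```

Because registration happens at attribute assignment, an optimizer built from registered.parameters() sees the layer's weights, while one built from local_only.parameters() receives an empty list.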