neural-network, classification, conv-neural-network, pytorch, transfer-learning

Can't replace classifier on Densenet121 in pytorch


I am trying to do some transfer learning using the DenseNet121 model from this GitHub repo (https://github.com/gaetandi/cheXpert.git). I'm running into issues resizing the classification layer from 14 outputs to 2.

The relevant part of the GitHub code is:

import torch.nn as nn
import torchvision

class DenseNet121(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        # torchvision's DenseNet121 is stored as a submodule named "densenet121"
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features  # 1024
        # the head (Linear + Sigmoid) lives at self.densenet121.classifier
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x

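I initialize and load it with:

import torch

# initialize and load the model
# (nnClassCount is 14 for this checkpoint, matching out_features=14 below)
model = DenseNet121(nnClassCount).cuda()
# DataParallel stores the wrapped DenseNet121 instance under model.module
model = torch.nn.DataParallel(model).cuda()
modeldict = torch.load("model_ones_3epoch_densenet.tar")
model.load_state_dict(modeldict['state_dict'])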

It looks like DenseNet doesn't split its layers into children, so model = nn.Sequential(*list(modelRes.children())[:-1]) won't work.

Assigning model.classifier = nn.Linear(1024, 2) seems to work on a stock DenseNet, but with this modified classifier (the one with the extra sigmoid) it just adds another classifier layer instead of replacing the original.

I've also tried:

model.classifier = nn.Sequential(
    nn.Linear(1024, dset_classes_number), 
    nn.Sigmoid()
)

but I get the same problem: the new classifier is added instead of replacing the original:

...
      )
      (classifier): Sequential(
        (0): Linear(in_features=1024, out_features=14, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=2, bias=True)
    (1): Sigmoid()
  )
)

Solution

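  • If you want to replace the classifier inside densenet121, which is a member of your model, you need to assign to that attribute rather than to model.classifier. Because your model is also wrapped in torch.nn.DataParallel, the DenseNet121 instance sits under model.module, so the assignment is

    model.module.densenet121.classifier = nn.Sequential(...)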