
How to change the out_features of densenet121 model?

I am using the code below to train the model:

import torch
import torch.nn as nn
from torch.nn import AvgPool2d
from torchvision import models

class Densnet121(nn.Module):
    def __init__(self):
        super(Densnet121, self).__init__() 
        self.cnn1 = nn.Conv2d(in_channels=3 , out_channels=64 , kernel_size=3 , stride=1 )
        self.Densenet_121 = models.densenet121(pretrained=True)
        self.gap = AvgPool2d(kernel_size=2, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.do1 = nn.Dropout(0.25)
        self.linear = nn.Linear(256,256)
        self.bn2 = nn.BatchNorm2d(256)
        self.do2 = nn.Dropout(0.25)
        self.output = nn.Linear(64 * 64 * 64,2)
        self.act = nn.ReLU()
        
    def densenet(self):
        for param in self.Densenet_121.parameters():
            param.requires_grad = False
        self.Densenet_121.classifier = nn.Linear(1024, 1024)
        return self.Densenet_121
    
    def forward(self, x):
        img = self.act(self.cnn1(x))
        img = self.densenet()(img)
    
        img = self.gap(img)
        img = self.bn1(img)
        img = self.do1(img)
        img = self.linear(img)
        img = self.bn2(img)
        img = self.do2(img)
        img = torch.flatten(img, 1)
        img = self.output(img)
    
        return img

When training this model, I face the following error:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[64, 64, 62, 62] to have 3 channels, but got 64 channels instead

Solution

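Your first conv layer outputs a tensor of shape (b, 64, h, w), while the layer that follows it, the DenseNet model, expects 3 input channels. Hence the error that was raised:

"expected input [...] to have 3 channels, but got 64 channels instead"

Unfortunately, this value is hard-coded in the source of torchvision's DenseNet class: its first layer, features.conv0, is built as nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False). For DenseNet-121, num_init_features is 64, which gives the [64, 3, 7, 7] weight shown in the error message.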
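To see the mismatch in isolation, here is a minimal sketch that uses a plain nn.Conv2d with the same specs as DenseNet-121's features.conv0 instead of loading the full model (a simplification for illustration). The batch size of 2 and the 64x64 input size are assumptions; the 64x64 size is inferred from the 62x62 spatial shape in the error.

```python
import torch
import torch.nn as nn

# cnn1 from the question: 3 -> 64 channels, 3x3 kernel, stride 1, no padding
cnn1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1)
# Stand-in with the same specs as DenseNet-121's features.conv0
conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

x = torch.randn(2, 3, 64, 64)  # assumed: batch of 2 RGB 64x64 images
h = cnn1(x)
print(h.shape)  # torch.Size([2, 64, 62, 62])

# conv0 expects 3 input channels but receives 64, so this raises
try:
    conv0(h)
    error = None
except RuntimeError as e:
    error = str(e)
print(error)
```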

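One workaround, however, is to overwrite the first convolutional layer after the DenseNet model has been created (e.g. in __init__, right after self.Densenet_121 is assigned). Something like this should work:

# First gather the specs of the existing conv layer
conv = self.Densenet_121.features.conv0
kwargs = {k: getattr(conv, k) for k in
          ('out_channels', 'stride', 'kernel_size', 'padding')}

# Overwrite it with a layer of identical specs but a new in_channels;
# nn.Conv2d expects a bool for bias, not the bias tensor itself
self.Densenet_121.features.conv0 = nn.Conv2d(
    in_channels=64, bias=conv.bias is not None, **kwargs)

Note that the new layer is randomly initialized, so the pretrained weights of conv0 are lost.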
    

Alternatively, you can replace the weight in place:

w = self.Densenet_121.features.conv0.weight
# New shape: (out_channels, 64, kernel_h, kernel_w)
w.data = torch.rand(len(w), 64, *w.shape[2:])

This replaces the underlying convolutional weight without updating the layer's metadata (e.g. conv.in_channels remains equal to 3), which could have side effects. So I would recommend the first approach.
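The side effect of the weight-only approach can be seen with a small sketch, again using a plain nn.Conv2d with conv0's specs as a stand-in for the DenseNet layer (an assumption to keep it self-contained). The forward pass works because it reads the weight tensor directly, but in_channels and the printed layer description still claim 3 input channels.

```python
import torch
import torch.nn as nn

# Stand-in with the same specs as DenseNet-121's features.conv0
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

w = conv.weight
w.data = torch.rand(len(w), 64, *w.shape[2:])  # new weight shape: (64, 64, 7, 7)

# The forward pass uses the new weight, so a 64-channel input now works
out = conv(torch.randn(2, 64, 62, 62))
print(tuple(out.shape))          # (2, 64, 31, 31)
print(tuple(conv.weight.shape))  # (64, 64, 7, 7)
print(conv.in_channels)          # 3: the metadata was not updated
print(conv)                      # the repr still reports 3 input channels
```

Any code that reads conv.in_channels, such as the spec-copying in the first approach, would see the stale value.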