How to change the out_features of densenet121 model?
I am using the code below to train the model:
from torch.nn.modules.dropout import Dropout
class Densnet121(nn.Module):
    """Network that chains a conv stem, a DenseNet-121 backbone and a small head.

    All layers are created in ``__init__`` and applied in order in ``forward``;
    the last layer is a two-unit ``nn.Linear``.

    NOTE(review): ``models`` and ``AvgPool2d`` are not imported in the visible
    code -- presumably ``torchvision.models`` and ``torch.nn.AvgPool2d``;
    confirm against the real import block.
    """

    def __init__(self):
        """Create every layer that ``forward`` uses."""
        super(Densnet121, self).__init__()
        # Stem: 3 input channels -> 64 output channels, 3x3 kernel, stride 1.
        self.cnn1 = nn.Conv2d(in_channels=3 , out_channels=64 , kernel_size=3 , stride=1 )
        # Backbone, requested with pretrained weights.
        self.Densenet_121 = models.densenet121(pretrained=True)
        # NOTE(review): written without the ``nn.`` prefix, unlike the other layers.
        self.gap = AvgPool2d(kernel_size=2, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.do1 = nn.Dropout(0.25)
        self.linear = nn.Linear(256,256)
        self.bn2 = nn.BatchNorm2d(256)
        self.do2 = nn.Dropout(0.25)
        # Final layer: 64 * 64 * 64 input features -> 2 outputs.
        self.output = nn.Linear(64 * 64 * 64,2)
        self.act = nn.ReLU()

    def densenet(self):
        """Freeze the backbone, replace its classifier, and return the backbone.

        Takes no input tensor: the return value is the ``self.Densenet_121``
        module itself, not the result of applying it to data.
        """
        # Turn off gradients for every parameter currently in the backbone.
        for param in self.Densenet_121.parameters():
            param.requires_grad = False
        # Assigned after the loop above, so this layer is not touched by it.
        # A new nn.Linear is created on every call of this method.
        self.Densenet_121.classifier = nn.Linear(1024, 1024)
        return self.Densenet_121

    def forward(self, x):
        """Pass ``x`` through the stem, backbone and head; return the last layer's output."""
        # Stem followed by ReLU; cnn1 produces 64 output channels.
        img = self.act(self.cnn1(x))
        # NOTE(review): densenet() is declared with no tensor parameter but is
        # called with one here. If the backbone is meant to be applied to img,
        # its first conv would receive 64 channels -- TODO confirm it accepts that.
        img = self.densenet(img)
        img = self.gap(img)
        img = self.bn1(img)
        img = self.do1(img)
        img = self.linear(img)
        img = self.bn2(img)
        img = self.do2(img)
        # Collapse all dimensions from dim 1 onward into one.
        img = torch.flatten(img, 1)
        img = self.output(img)
        return img
When training this model, I face the following error:
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[64, 64, 62, 62] to have 3 channels, but got 64 channels instead
Your first conv layer outputs a tensor of shape (b, 64, h, w),
while the following layer, the densenet model, expects 3 channels. Hence the error that was raised:
"expected input [...] to have 3 channels, but got 64 channels instead"
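Unfortunately, the number of input channels is hardcoded in the source of the Densenet class (its first layer is built as nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), see reference.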
One workaround however is to overwrite the first convolutional layer after the densenet has been initialized. Something like this should work:
# Replace DenseNet's first conv so that it accepts 64 input channels.
# Run this once, right after the backbone has been created (e.g. in __init__).
model = self.Densenet_121
conv = model.features.conv0

# Copy the existing layer's geometry; only in_channels changes.
kwargs = {k: getattr(conv, k) for k in
          ('out_channels', 'kernel_size', 'stride', 'padding')}

# conv.bias holds a Parameter or None, whereas nn.Conv2d expects a bool flag.
# The new layer is freshly initialised (conv0's pretrained weights are not
# reused), so keep it trainable even if the rest of the backbone is frozen.
model.features.conv0 = nn.Conv2d(in_channels=64,
                                 bias=conv.bias is not None,
                                 **kwargs)
Alternatively, you can do:
# `model` is the DenseNet instance whose first conv should take 64 channels.
# conv0.weight has shape (out_channels, in_channels, kH, kW): keep the output
# channels (len(w)) and the kernel size (w.shape[2:]), and set in_channels to 64.
w = model.features.conv0.weight
w.data = torch.rand(len(w), 64, *w.shape[2:])
This replaces the underlying convolutional layer's weight without updating the layer's metadata (e.g. conv.in_channels remains equal to 3), which could have side effects. So I would recommend following the first approach.