
How to modify the "in_channels" of the first CNN layer in a timm model?


Hi everyone. I want to train a CV model from the timm library on my dataset. Because my input data has shape (batch_size, 15, 224, 224), I need to change the "in_channels" of the first CNN layer of different CV models. I have tried different methods, but they all fail. Could you help me solve this problem? Thanks!

import torch
import torch.nn as nn

import timm

class FrequencyModel(nn.Module):

    def __init__(
        self, 
        in_channels = 6, 
        output = 9, 
        model_name = 'resnet200d', 
        pretrained = False
        ):

        super(FrequencyModel, self).__init__()
        
        self.in_channels = in_channels
        self.output = output
        self.model_name = model_name
        self.pretrained = pretrained

        self.m = timm.create_model(self.model_name, pretrained=self.pretrained, num_classes=output)

        for layer in self.m.modules():
            if(isinstance(layer,nn.Conv2d)):
                layer.in_channels = self.in_channels
                break

    def forward(self,x):
        
        out=self.m(x)

        return out

if __name__ == "__main__":
    
    x = torch.randn((8, 15, 224, 224))

    model=FrequencyModel(
        in_channels = 15, 
        output = 9, 
        model_name = 'resnet200d', 
        pretrained = False
    )
    print(model)
    print(model(x).shape)

The error is:

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[8, 15, 224, 224] to have 3 channels, but got 15 channels instead

I would like to be able to test different CV models easily, without adjusting each one individually.


Solution

  • Do you want to change the first conv2d layer of resnet200d?

    Let me point out a few things that went wrong here.

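    1. Changing only the in_channels attribute does not change the shape of the layer's weight tensor. The weight is still [32, 3, 3, 3], so the forward pass still expects 3 input channels.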
    2. resnet200d uses a deep stem: conv1 is an nn.Sequential, and the first Conv2d sits inside it at conv1[0]. Note that m.modules() (like the apply method) already iterates recursively, so the for loop in the question does reach this layer; the failure comes from point 1, not from the traversal.
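    3. To actually replace a layer, assign a new module on its parent, e.g. m.conv1[0] = ..., setattr(parent, name, new_layer) or parent._modules[name] = new_layer. m.modules() yields references to the existing layers (no copy is made), but it does not tell you each layer's parent or attribute name, and rebinding the loop variable to a new nn.Conv2d leaves the model unchanged. Use model.named_modules() or named_children() to get both the name and the module.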
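A small self-contained check of points 1 to 3, using plain PyTorch. The toy model below is an assumption that only loosely mimics resnet200d's nested stem (a Conv2d inside an nn.Sequential); it is not the real timm architecture.

```python
import torch
import torch.nn as nn

# Toy model with a nested stem, loosely mimicking resnet200d's conv1 = nn.Sequential(Conv2d, ...).
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False), nn.ReLU()),
    nn.AdaptiveAvgPool2d(1),
)

# modules() recurses, so this finds the nested Conv2d, and it is a reference, not a copy.
first_conv = next(mod for mod in model.modules() if isinstance(mod, nn.Conv2d))
assert first_conv is model[0][0]

# Point 1: editing the attribute leaves the weight tensor untouched.
first_conv.in_channels = 15
print(first_conv.weight.shape)  # torch.Size([32, 3, 3, 3])
try:
    model(torch.randn(2, 15, 32, 32))
except RuntimeError as err:
    print("still fails:", err)

# Point 3: rebinding the loop variable does not modify the model either.
for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        layer = nn.Conv2d(15, 32, 3, stride=2, padding=1, bias=False)
        break
print(model[0][0].weight.shape)  # still torch.Size([32, 3, 3, 3])

# Replacing the layer on its parent module is what actually works.
model[0][0] = nn.Conv2d(15, 32, 3, stride=2, padding=1, bias=False)
print(model(torch.randn(2, 15, 32, 32)).shape)  # torch.Size([2, 32, 1, 1])
```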

    So you probably want to change it like this:

    import torch
    import torch.nn as nn
    import timm
    
    x = torch.randn((8, 15, 224, 224))
    m = timm.create_model('resnet200d', pretrained=False, num_classes=9)
    # resnet200d uses a deep stem: m.conv1 is an nn.Sequential whose first
    # element is Conv2d(3, 32, 3, stride=2, padding=1, bias=False).
    m.conv1[0] = nn.Conv2d(15, 32, 3, stride=2, padding=1, bias=False)
    print(m)
    print(m(x).shape)  # torch.Size([8, 9])
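The replacement conv above starts from random weights, which is fine with pretrained=False. If you load pretrained weights, you may want to keep the stem's learned filters. A minimal sketch of one common heuristic, assuming it suits your data: average the old filters over their input channels, tile the mean to the new channel count, and rescale so activations keep roughly the same magnitude. The helper name widen_conv_in_channels is hypothetical, and it assumes groups == 1, which holds for typical stem convs.

```python
import torch
import torch.nn as nn


def widen_conv_in_channels(old: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: return a copy of `old` that accepts `in_channels`
    inputs, initialised from the old filters (mean over input channels, tiled)."""
    assert old.groups == 1, "sketch only handles ungrouped convolutions"
    new = nn.Conv2d(
        in_channels,
        old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        dilation=old.dilation,
        bias=old.bias is not None,
        padding_mode=old.padding_mode,
    )
    with torch.no_grad():
        # Mean over the old input channels: shape [out, 1, kh, kw].
        mean_w = old.weight.mean(dim=1, keepdim=True)
        # Tile to [out, in_channels, kh, kw]; scaling by old/new channel count keeps
        # the per-position sum of weights equal to the original one.
        new.weight.copy_(mean_w.repeat(1, in_channels, 1, 1) * old.in_channels / in_channels)
        if old.bias is not None:
            new.bias.copy_(old.bias)
    return new


if __name__ == "__main__":
    stem = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
    print(widen_conv_in_channels(stem, 15))
```

With the timm model from the snippet above, this would be used as m.conv1[0] = widen_conv_in_channels(m.conv1[0], 15). Averaging is only one choice; copying the RGB filters into specific channels or random initialisation may work better depending on what the 15 channels represent.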
    

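    More generally, the change_first_layer function below walks the module tree depth-first and replaces the first Conv2d it finds with a new one that takes 15 input channels. In most models that is the stem conv, although the first registered Conv2d is not guaranteed to be the first one applied in forward.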

    import torch
    import torch.nn as nn
    import timm
    
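    def change_first_layer(m, in_channels=15):
        """Replace the first nn.Conv2d found (depth-first, in registration order)
        with a new, randomly initialised conv that takes `in_channels` inputs."""
        for name, child in m.named_children():
            if isinstance(child, nn.Conv2d):
                kwargs = {
                    'out_channels': child.out_channels,
                    'kernel_size': child.kernel_size,
                    'stride': child.stride,
                    'padding': child.padding,
                    'dilation': child.dilation,
                    'bias': child.bias is not None,
                }
                # Assign on the parent so the model really uses the new layer.
                setattr(m, name, nn.Conv2d(in_channels, **kwargs))
                return True
            elif change_first_layer(child, in_channels):
                return True
        return False
    
    x = torch.randn((8, 15, 224, 224))
    m = timm.create_model('resnet200d', pretrained=False, num_classes=9)
    change_first_layer(m, in_channels=15)
    print(m)
    print(m(x).shape)  # torch.Size([8, 9])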