pytorch, transfer-learning, unet-neural-network, fine-tuning

Fine-tuning nnUNet with frozen weights


Good morning, I've followed the instructions in this GitHub issue:

https://github.com/MIC-DKFZ/nnUNet/issues/1108

to fine-tune an nnUNet model (PyTorch) on a pre-trained one, but this method retrains all weights. I would like to freeze all weights and retrain only the last layer's weights, changing the number of segmentation classes from 3 to 1. Do you know a way to do that? Thank you in advance.


Solution

  • To freeze the weights you need to set parameter.requires_grad = False.

    Example:

    from nnunet.network_architecture.generic_UNet import Generic_UNet
    
    model = Generic_UNet(input_channels=3, base_num_features=64, num_classes=4, num_pool=3)
    
    # Freeze every parameter except the segmentation heads ('seg_outputs').
    for name, parameter in model.named_parameters():
        if 'seg_outputs' in name:
            print(f"parameter '{name}' will not be frozen")
            parameter.requires_grad = True
        else:
            parameter.requires_grad = False
    
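    Since the question also asks for changing the number of segmentation classes from 3 to 1, the frozen model's output heads need to be replaced with fresh ones. A minimal sketch, assuming the seg_outputs ModuleList shown in the printout below (with deep supervision enabled there can be several heads, not just one):

    import torch.nn as nn
    
    # Swap each segmentation head for a new 1-class output layer.
    # in_channels/kernel_size/stride mirror the existing heads; this is a
    # sketch, not the official nnUNet fine-tuning API.
    for i, head in enumerate(model.seg_outputs):
        model.seg_outputs[i] = nn.Conv2d(
            in_channels=head.in_channels,
            out_channels=1,                 # new number of classes
            kernel_size=head.kernel_size,
            stride=head.stride,
            bias=False,
        )
    
    # Newly created layers have requires_grad=True by default, so only
    # these heads will be updated during training.
    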

    To check the parameter names, you can print the model:

    print(model)
    

    which produces:

    Generic_UNet(
      (conv_blocks_localization): ModuleList(
        (0): Sequential(
          (0): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
          (1): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
        )
      )
      (conv_blocks_context): ModuleList(
        (0): StackedConvLayers(
          (blocks): Sequential(
            (0): ConvDropoutNormNonlin(
              (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (dropout): Dropout2d(p=0.5, inplace=True)
              (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
            )
            (1): ConvDropoutNormNonlin(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (dropout): Dropout2d(p=0.5, inplace=True)
              (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
        )
        (1): Sequential(
          (0): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (dropout): Dropout2d(p=0.5, inplace=True)
                (instnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
          (1): StackedConvLayers(
            (blocks): Sequential(
              (0): ConvDropoutNormNonlin(
                (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (dropout): Dropout2d(p=0.5, inplace=True)
                (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
          )
        )
      )
      (td): ModuleList(
        (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (tu): ModuleList(
        (0): Upsample()
      )
      (seg_outputs): ModuleList(
        (0): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
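
    Before training, it is worth verifying which parameters actually remain trainable and passing only those to the optimizer. A short sketch (the optimizer choice and learning rate are placeholders, not nnUNet's training defaults):

    import torch
    
    # Collect the parameters left trainable by the freezing loop above.
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    for name, _ in trainable:
        print(f"trainable: {name}")
    
    # Frozen parameters would receive no gradients anyway, but filtering
    # them out makes the intent explicit.
    optimizer = torch.optim.SGD((p for _, p in trainable), lr=1e-3, momentum=0.9)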
    
    

    Or you can visualize your network with netron:

    https://github.com/lutzroeder/netron
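
    netron reads saved model files rather than live Python objects, so one option is to export the model to ONNX first. A hedged sketch (file name, opset version, and dummy input shape are placeholders; the spatial size only needs to be divisible by 2**num_pool, here 8):

    import torch
    
    model.eval()  # disable dropout so the exported graph is deterministic
    
    # Placeholder input: batch of 1, 3 channels, 64x64 pixels.
    dummy = torch.randn(1, 3, 64, 64)
    
    # Write the graph to disk, then open the .onnx file in netron.
    torch.onnx.export(model, dummy, "generic_unet.onnx", opset_version=11)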