Good morning, I've followed the instructions in this GitHub issue:
https://github.com/MIC-DKFZ/nnUNet/issues/1108
to fine-tune an nnUNet model (PyTorch) on a pre-trained one, but this method retrains all weights. I would like to freeze all weights and retrain only the last layer's weights, changing the number of segmentation classes from 3 to 1. Do you know a way to do that? Thank you in advance.
To freeze the weights you need to set parameter.requires_grad = False.
Example:

from nnunet.network_architecture.generic_UNet import Generic_UNet

model = Generic_UNet(input_channels=3, base_num_features=64, num_classes=4, num_pool=3)

for name, parameter in model.named_parameters():
    if 'seg_outputs' in name:
        print(f"parameter '{name}' will not be frozen")
        parameter.requires_grad = True
    else:
        parameter.requires_grad = False
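Since you also want to change the number of segmentation classes (3 to 1 in your case, 4 to 1 for the example model above), one option is to replace each seg_outputs head with a fresh 1-class convolution after freezing, and then hand only the still-trainable parameters to the optimizer. This is a minimal sketch, not the official nnUNet fine-tuning path, and the SGD settings are only illustrative:

import torch
import torch.nn as nn

# Hypothetical follow-up to the freeze above: swap every deep-supervision
# head for a fresh 1-class output. in_channels is read from the old head,
# so this matches whatever model was loaded. Newly created layers default
# to requires_grad=True, so only they will be trained.
for i, old_head in enumerate(model.seg_outputs):
    model.seg_outputs[i] = nn.Conv2d(old_head.in_channels, 1,
                                     kernel_size=1, stride=1, bias=False)

# Pass only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.01, momentum=0.99, nesterov=True,  # illustrative values
)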
To check the parameter names you can use print:
print(model)
which produces:
Generic_UNet(
  (conv_blocks_localization): ModuleList(
    (0): Sequential(
      (0): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
      (1): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
  )
  (conv_blocks_context): ModuleList(
    (0): StackedConvLayers(
      (blocks): Sequential(
        (0): ConvDropoutNormNonlin(
          (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (dropout): Dropout2d(p=0.5, inplace=True)
          (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
        (1): ConvDropoutNormNonlin(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (dropout): Dropout2d(p=0.5, inplace=True)
          (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
    )
    (1): Sequential(
      (0): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout2d(p=0.5, inplace=True)
            (instnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
      (1): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout2d(p=0.5, inplace=True)
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
  )
  (td): ModuleList(
    (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (tu): ModuleList(
    (0): Upsample()
  )
  (seg_outputs): ModuleList(
    (0): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)
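To double-check the result of the freeze, you can count the parameters that remain trainable (a quick sanity check, not nnUNet-specific):

# Count trainable vs. total parameters after setting requires_grad.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable} of {total} parameters remain trainable")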
Or you can visualize your network with netron.
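netron opens saved model files rather than live Python objects, so one way is to export to ONNX first. A minimal sketch, assuming a 2D input whose spatial size survives the pooling steps; the opset version and file name are illustrative, and exporting the deep-supervision outputs may need adjustments:

import torch

# Export the model to ONNX so netron can display the graph.
model.eval()
dummy = torch.randn(1, 3, 64, 64)  # (batch, channels, H, W)
torch.onnx.export(model, dummy, "generic_unet.onnx", opset_version=11)
# Now open generic_unet.onnx in netron to browse the architecture.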