As the title says: I already have a pretrained backbone, and I want to train only the RPN, not the classifier, using Faster R-CNN from torchvision.
Is there a parameter I can pass to the create_model function, or should I stop the classifier from training in my train() function?
I'm on mobile, so please excuse my editing.
This is my create_model function:
import timm
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
# Create your backbone from timm (the model name was not included in this post)
backbone = timm.create_model(
    num_classes=0,   # this is important to remove fc layers
    global_pool="",  # this is important to remove fc layers
)
backbone.out_channels = backbone.feature_info[-1]["num_chs"]
anchor_generator = AnchorGenerator(
    sizes=((16, 32, 64, 128, 256),), aspect_ratios=((0.25, 0.5, 1.0, 2.0),)
)
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=["0"], output_size=7, sampling_ratio=2
)
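fastercnn_model = FasterRCNN(  # num_classes and any other arguments were cut off in this post
    backbone,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)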
You can do the following:

# First you can use the model.children() method to see the idx of the backbone
for idx, child in enumerate(fastercnn_model.children()):
    if idx == 1:
        # Now set requires_grad for that idx to False
        for param in child.parameters():
            param.requires_grad = False
# =============== UPDATED ========================
# This will train only the box_predictor, not even the RPN. You can try out
# different strategies and find the best one for you.
# Set everything to False
for child in fastercnn_model.children():
    for param in child.parameters():
        param.requires_grad = False
# idx 3 is roi_heads; unfreeze only its box_predictor
for idx, child in enumerate(fastercnn_model.children()):
    if idx == 3:
        for param in child.box_predictor.parameters():
            param.requires_grad = True