Tags: deep-learning, pytorch, image-segmentation

Which parameters of Mask-RCNN control mask recall?


I'm interested in fine-tuning a Mask-RCNN model that I'm using for instance segmentation. Currently I have trained the model for 6 epochs and the various Mask-RCNN losses are as follows:

[Image: plot of the Mask-RCNN training losses]

The reason I'm stopping is that the COCO evaluation metrics seem to have dipped in the last epoch:

[Image: plot of the COCO evaluation metrics per epoch]

I know this is a far-reaching question, but I'm looking to gain some intuition for which parameters are likely to have the most impact on the evaluation metrics. As I understand it, there are three places to look:

  1. Should I be tuning the batch size, learning rate and momentum? I use an SGD optimizer with a learning rate of 1e-4 and a batch size of 2.
  2. Should I use more training data or add augmentation (I don't currently use any)? My dataset is already fairly large, at 40K images.
  3. Should I be looking at the specific MaskRCNN parameters?

I think I'll likely be asked to be more specific about what I want to improve, so: I would like to improve the recall of the individual masks. The model performs well but doesn't quite capture the full extent of what I would like it to. I'm also leaving out details of the specific learning problem, as I'd like to gain intuition for how to approach this in general.


Solution

  • A couple of notes:

    • 6 epochs are too few for the network to converge, even if you use a pre-trained network, and especially one as big as resnet50. I think you need at least 50 epochs. With a pre-trained resnet18 I started to get good results after 30 epochs, resnet34 needed another 10-20 epochs, and your resnet50 with a 40k-image training set definitely needs more than 6 epochs;

    • definitely use a pre-trained network;

    • in my experience, I could not get results I liked with SGD, so I switched to AdamW with a ReduceLROnPlateau scheduler. The network converges quite fast (about 50-60% AP by epoch 7 or 8), but it then climbs to 80-85% after 50-60 epochs through very small epoch-to-epoch improvements, and only if the LR is small enough. You are probably familiar with gradient descent; I like to picture it this way: the more augmentation you use, the more your "hill" is covered with "boulders" that you have to get around, and you can only do that if you control the LR. AdamW also helps with overfitting. For networks with a higher input resolution (the network rescales your input images itself), I use a higher LR. This is how I do it:

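      import torch

      # `model`, `train_loader`, `test_loader`, `scaler` (an AMP GradScaler or None), `device`,
      # `epochs` and `args` are defined elsewhere. `train_one_epoch` and `evaluate` are assumed
      # to come from torchvision's references/detection/engine.py, whose signature is
      # train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None).
      # `save_and_print_size_of_model` is a custom helper (not shown).
      params = [p for p in model.parameters() if p.requires_grad]
      init_lr = 0.00005
      weight_decay = init_lr * 100
      optimizer = torch.optim.AdamW(params, lr=init_lr, weight_decay=weight_decay)
      # multiply the LR by 0.75 once the monitored loss has not improved for more than 3 epochs
      scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=3, factor=0.75)

      for epoch in range(epochs):
          # train for one epoch, printing every 10 iterations
          metric_logger = train_one_epoch(model, optimizer, train_loader, device, epoch,
                                          print_freq=10, scaler=scaler)

          # step the scheduler on this epoch's average training loss
          scheduler.step(metric_logger.loss.global_avg)
          # keep the weight decay proportional to the current LR
          for group in optimizer.param_groups:
              group["weight_decay"] = group["lr"] * 100

          # evaluate on the test dataset
          evaluate(model, test_loader, device=device)

          print("[INFO] serializing model to '{}' ...".format(args["model"]))
          save_and_print_size_of_model(model, args["model"], script=False)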
      

    Pick the initial LR and weight decay so that, by the end of training, the scheduler has brought the LR down to a very small value, around 1/10 of the initial LR. If the loss plateaus too often, the scheduler quickly drives the LR to very small values and the network learns nothing for the rest of the epochs.
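To get a feel for how `factor` and `patience` interact, here is a minimal sketch that re-implements the core rule of `ReduceLROnPlateau` in plain Python, assuming the default `mode='min'` with a relative threshold of 1e-4 and ignoring cooldown and `min_lr`. The helper `plateau_lrs` is hypothetical and not part of PyTorch. With `factor=0.75`, about 8 reductions bring the LR to roughly 1/10 of its initial value, and each reduction needs `patience + 1` consecutive epochs without improvement, so a completely flat loss gets there in 32 epochs.

```python
import math


def plateau_lrs(losses, init_lr, factor=0.75, patience=3, threshold=1e-4):
    """Simplified ReduceLROnPlateau(mode='min', threshold_mode='rel') rule.

    Returns the LR in effect after the scheduler step at the end of each epoch.
    Cooldown, min_lr and eps are ignored in this sketch.
    """
    lr, best, num_bad_epochs = init_lr, math.inf, 0
    lrs = []
    for loss in losses:
        if loss < best * (1 - threshold):  # relative improvement over the best loss
            best, num_bad_epochs = loss, 0
        else:
            num_bad_epochs += 1
        if num_bad_epochs > patience:  # i.e. patience + 1 bad epochs in a row
            lr *= factor
            num_bad_epochs = 0
        lrs.append(lr)
    return lrs


init_lr, factor, patience = 0.00005, 0.75, 3

# Reductions needed to reach ~1/10 of the initial LR (0.75**8 is about 0.1001)
n_reductions = round(math.log(0.1) / math.log(factor))
print(n_reductions)  # 8

# Worst case: the loss never improves after the first epoch
lrs = plateau_lrs([1.0] * 33, init_lr, factor=factor, patience=patience)
print(lrs[-1] / init_lr)  # about 0.1001
```

In the training loop the monitored value is the epoch's average training loss, so every stretch of 4 stagnant epochs costs 25% of the LR; if that happens early and often, the LR can become too small to learn much in the remaining epochs.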

    Your plots suggest that at some point in training your LR is too high: the network stops improving and AP starts to go down. You need steady improvement, even if it is small. The longer the network trains, the more subtle details about your domain it learns, and the smaller the learning rate needs to be. IMHO, a constant LR will not let it do that properly.

    • anchor generator settings. Here is how I initialize the network.

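      import torchvision
      from torchvision.models.detection import MaskRCNN
      from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
      from torchvision.models.detection.rpn import AnchorGenerator


      def get_maskrcnn_resnet_model(name, num_classes, pretrained, res='normal'):
          print('Using maskrcnn with {} backbone...'.format(name))

          # ResNet + FPN backbone with all 5 layers trainable.
          # Note: torchvision >= 0.13 replaces the deprecated `pretrained` flag with `weights=`.
          backbone = resnet_fpn_backbone(backbone_name=name, pretrained=pretrained, trainable_layers=5)

          # One anchor size per FPN level ('0', '1', '2', '3', 'pool'), with 5 aspect ratios each
          sizes = ((4,), (8,), (16,), (32,), (64,))
          aspect_ratios = ((0.25, 0.5, 1.0, 2.0, 4.0),) * len(sizes)
          anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)

          # RoIAlign over the four FPN feature maps for the box head
          roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                          output_size=7, sampling_ratio=2)

          # Images are resized so the shorter side is min_size, unless the longer side
          # would then exceed max_size
          default_min_size = 800
          default_max_size = 1333

          if res == 'low':
              min_size = int(default_min_size / 1.25)
              max_size = int(default_max_size / 1.25)
          elif res == 'normal':
              min_size = default_min_size
              max_size = default_max_size
          elif res == 'high':
              min_size = int(default_min_size * 1.25)
              max_size = int(default_max_size * 1.25)
          else:
              raise ValueError('Invalid res={} param'.format(res))

          model = MaskRCNN(backbone, min_size=min_size, max_size=max_size, num_classes=num_classes,
                           rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler)

          # allow up to 512 detections per image (the default is 100)
          model.roi_heads.detections_per_img = 512
          return model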
      

    I need to detect small objects, which is why I use these anchor parameters, with sizes well below torchvision's defaults (32 to 512).
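As a rough illustration of what these settings mean, the sketch below computes base anchor widths and heights with the same formula as torchvision's `AnchorGenerator` (aspect ratio = height / width, area = size squared); torchvision additionally rounds the anchor corners to integers, which this sketch skips. The helpers `anchor_shapes` and `smallest_side` are hypothetical. Anchors live in the coordinates of the resized image (after the `min_size`/`max_size` rescaling), not the original one.

```python
import math


def anchor_shapes(sizes, aspect_ratios):
    """Return (width, height) of each base anchor, following AnchorGenerator's formula.

    height = size * sqrt(ratio) and width = size / sqrt(ratio), so that
    height / width == ratio and width * height == size**2 (before rounding).
    """
    shapes = []
    for size in sizes:
        for ratio in aspect_ratios:
            h_ratio = math.sqrt(ratio)
            shapes.append((size / h_ratio, size * h_ratio))
    return shapes


def smallest_side(shapes):
    """Shortest anchor side in pixels across all anchors."""
    return min(min(w, h) for w, h in shapes)


# Sizes/ratios from the answer (one size per FPN level) vs. torchvision's defaults
custom = anchor_shapes([4, 8, 16, 32, 64], [0.25, 0.5, 1.0, 2.0, 4.0])
default = anchor_shapes([32, 64, 128, 256, 512], [0.5, 1.0, 2.0])

print(smallest_side(custom), round(smallest_side(default), 1))  # 2.0 22.6
```

With the custom settings the thinnest anchor side is 2 px instead of roughly 22.6 px, and each location gets 5 anchors per level instead of 3, which gives the RPN more chances to match small or elongated objects.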

    • class imbalance. If you only have your object class and the background, there is no problem. If you have more classes, make sure that your train/test split (e.g. 80% for training and 20% for testing) is applied more or less exactly to every class used in your training.
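A minimal sketch of a per-class 80/20 split in plain Python, assuming (hypothetically) that each image can be given a single dominant class, for example the most frequent class among its instances. The function `stratified_split` and the toy labels are illustrative only; images with several classes would need a proper multi-label stratification, which this sketch does not attempt.

```python
import random
from collections import Counter, defaultdict


def stratified_split(image_labels, train_frac=0.8, seed=0):
    """Split image ids so that every class keeps roughly the same train/test ratio.

    image_labels maps image id -> one (dominant) class label per image.
    """
    by_class = defaultdict(list)
    for image_id, label in image_labels.items():
        by_class[label].append(image_id)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    train_ids, test_ids = [], []
    for label in sorted(by_class):
        ids = sorted(by_class[label])
        rng.shuffle(ids)
        n_train = round(len(ids) * train_frac)
        train_ids.extend(ids[:n_train])
        test_ids.extend(ids[n_train:])
    return train_ids, test_ids


# Toy example: an imbalanced dataset with 50 "cat" images and 10 "dog" images
labels = {f"img{i:03d}": ("cat" if i < 50 else "dog") for i in range(60)}
train_ids, test_ids = stratified_split(labels)
print(len(train_ids), len(test_ids))  # 48 12
print(Counter(labels[i] for i in train_ids), Counter(labels[i] for i in test_ids))
```

Splitting each class separately keeps the rare class represented in both sets (8 train / 2 test here); a single random 80/20 split over all images could leave it with very few test examples.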

    Good luck!