deep-learning, pytorch, transfer-learning, semantic-segmentation

Transfer Learning Segmentation Model Performing Significantly Worse on Test Data


I am quite new to the field of semantic segmentation and have recently tried to run the code accompanying the paper Transfer Learning for Brain Tumor Segmentation, which is available on GitHub. It is a semantic segmentation task on the BraTS2020 dataset, which comprises 4 modalities: T1, T1ce, T2 and FLAIR. The author utilised a transfer learning approach based on ResNet34 weights.

Due to hardware constraints, I had to halve the batch size from 24 to 12. However, after training the model, I noticed a significant drop in performance: the Dice scores (higher is better) of the 3 classes were only around 5-19-11, as opposed to the 78-87-82 reported in the paper. The training and validation accuracies, however, look normal; the model simply does not perform well on the test data. I also selected the checkpoint produced just before overfitting set in (the point where validation loss starts increasing while training loss is still decreasing), but it yielded equally bad results.
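To be concrete, this is roughly what I mean by selecting the model from before overfitting: keep the weights from the epoch with the lowest validation loss. A minimal sketch with placeholder helpers, not the repository's actual training loop:

```python
import copy
import torch

def train_with_best_checkpoint(model, train_one_epoch, compute_val_loss, max_epochs=100):
    """Keep the weights from the epoch with the lowest validation loss,
    i.e. the checkpoint saved just before validation loss starts rising."""
    best_val_loss = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(max_epochs):
        train_one_epoch(model)              # one pass over the training batches
        val_loss = compute_val_loss(model)  # run under model.eval() / torch.no_grad()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, "best_model.pt")
    model.load_state_dict(best_state)       # this checkpoint is used on the test set
    return model
```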

So far I have tried:

  1. Decreasing the learning rate from 1e-3 to 1e-4, which yielded similar results.
  2. Increasing the number of batches fed to the model per training epoch from 100 to 200, to match the number of iterations run in the paper (100 batches per epoch at a batch size of 24), since I effectively halved the batch size; see the sketch after this list.
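This is the arithmetic behind point 2, as a quick sanity check (numbers taken from the paper's configuration as described above):

```python
# Samples seen per epoch in the paper: 100 batches of 24 images.
paper_batch_size, paper_batches_per_epoch = 24, 100
samples_per_epoch = paper_batch_size * paper_batches_per_epoch   # 2400

# With the batch size halved to 12, matching the samples seen per epoch
# requires twice as many batches (iterations) per epoch.
my_batch_size = 12
my_batches_per_epoch = samples_per_epoch // my_batch_size        # 200
print(my_batches_per_epoch)
```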

I noticed that image augmentations were applied to the training and validation datasets to increase the robustness of the model training. Do these augmentations need to be performed on the test set in order to make predictions? There are no resizing transforms; the transforms present are Gaussian blur and noise, changes in brightness intensity, rotations, elastic deformation, and mirroring, all implemented following the example here.
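For context, here is a rough sketch of what those training/validation augmentations look like. This is illustrative NumPy/SciPy code with made-up parameters, not the repository's actual pipeline, and elastic deformation is omitted for brevity:

```python
import numpy as np
from scipy.ndimage import gaussian_filter, rotate

def augment(image):
    """Illustrative stochastic augmentations for a (C, H, W) slice, mirroring the
    transform types listed above; parameter ranges are placeholders."""
    if np.random.rand() < 0.5:   # Gaussian blur (spatial dims only)
        s = np.random.uniform(0.5, 1.5)
        image = gaussian_filter(image, sigma=(0, s, s))
    if np.random.rand() < 0.5:   # Gaussian noise
        image = image + np.random.normal(0.0, 0.05, image.shape)
    if np.random.rand() < 0.5:   # brightness intensity change
        image = image * np.random.uniform(0.9, 1.1)
    if np.random.rand() < 0.5:   # in-plane rotation
        image = rotate(image, angle=np.random.uniform(-15, 15),
                       axes=(-2, -1), reshape=False, order=1)
    if np.random.rand() < 0.5:   # mirroring
        image = np.flip(image, axis=-1).copy()
    return image
```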

I'd greatly appreciate help on these questions:

  1. Since the batch size is halved, doubling the number of batches per epoch effectively matches the number of iterations performed in the original paper. Is this the correct approach?

  2. Does the test set need to be augmented in the same way as the training data in order to make predictions? (Note: no resizing transforms were applied.)


Solution

    1. Technically, with a smaller batch size the number of iterations needs to be higher for convergence, so increasing the batches per epoch is the right idea. It will help, but it probably won't fully recover the performance you would get from training with the original batch size of 24.


    2. Usually, we don't apply augmentation to test data. However, if a deterministic transformation applied to the training and validation data (e.g. intensity normalization) is not also applied to the test data, the test performance will certainly be poor. You can try test-time augmentation, though, even if it's not very common for segmentation tasks: https://github.com/qubvel/ttach
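If you want to try the test-time augmentation route, a minimal sketch with ttach could look like this. The U-Net from segmentation_models_pytorch is just a stand-in for your trained model, not necessarily the paper's architecture, and the dummy batch stands in for your test loader:

```python
import torch
import ttach as tta                          # pip install ttach
import segmentation_models_pytorch as smp    # stand-in model only

# Placeholder: a U-Net with a ResNet34 encoder, 4 input modalities, 3 classes.
# In practice you would load your trained weights instead.
model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=4, classes=3)
model.eval()

# Predictions from flipped/rotated copies of each input are mapped back to the
# original orientation and averaged.
tta_model = tta.SegmentationTTAWrapper(
    model,
    tta.Compose([tta.HorizontalFlip(), tta.Rotate90(angles=[0, 180])]),
    merge_mode="mean",
)

with torch.no_grad():
    images = torch.randn(2, 4, 128, 128)     # dummy batch; use your test loader here
    masks = tta_model(images)                # same shape as model(images)
```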