deep-learning, classification, pre-trained-model, transfer-learning

How to choose which pre-trained weights to use for my model?


I am a beginner, and I am very confused about how to choose a pre-trained model that will improve my model's performance.

I am trying to create a cat breed classifier starting from the pre-trained weights of another model, say VGG16 trained on a digits dataset. Will that improve the performance of my model? Or would it be better to train my model only on my own dataset without any pre-trained weights? Or will both give the same result, since the pre-trained weights are just a starting point?

Also, if I use the weights of a VGG16 trained on cat-vs-dog data as the starting point for my cat breed classifier, will that help improve the model?


Solution

  • Sane weight initialization

    The choice of pre-trained weights depends on the classes you wish to classify. Since you want to classify cat breeds, use pre-trained weights from a classifier trained on a similar task. As other answers have mentioned, the initial layers learn things like edges, horizontal or vertical lines, blobs, etc. As you go deeper, the model starts learning problem-specific features. So for generic tasks you can start from, say, ImageNet weights and then fine-tune the model for the problem at hand.

    However, having a pre-trained model whose training data closely resembles yours helps immensely. A while ago, I participated in a Scene Classification Challenge where we initialized our model with ResNet50 weights trained on the Places365 dataset. Since the classes in that challenge were all present in Places365, we used the weights available here and fine-tuned our model. This gave us a large boost in accuracy, and we ended up among the top positions on the leaderboard. You can find more details in this blog.

    Also, understand that one of the advantages of transfer learning is saving computation. Using a model with randomly initialized weights is the same as training a neural net from scratch. If you use VGG16 weights trained on a digits dataset, the model might already have learned something useful, such as low-level edge detectors, so it should save some training time. If you train from scratch, the model will eventually learn all the patterns that the pre-trained digit-classifier weights would have given it.

    On the other hand, using weights from a cat-vs-dog classifier should give you better performance, since that model has already learned features for detecting, say, paws, ears, noses, or whiskers.