Tags: python, tensorflow, computer-vision, image-comparison, siamese-network

How to use a trained Siamese network to predict labels for a large test set with 100+ classes?


Do I have to compare each test image to an example image from each class? The test set contains around 7400 images across 104 classes, so this would be 7400 × 104 = 769,600 predictions?

Using TensorFlow on TPUs, I was able to train the model quite effectively. However, predicting the labels with the method above takes a very long time, and the model.predict call also leaks memory until the kernel eventually crashes (memory usage can climb past 30 GB and keep growing).


Solution

  • There are multiple ways to do this:

    • (Not recommended) This is basically a subset of what you're already doing: pick a few images from each class and compare your test image against them. For example, if you select 5 images from every class, you'll have to do 5 × 104 = 520 predictions per test image.
    • You can use a k-nearest neighbors (KNN) model, where each of your 7400 images (or a subset of them) only has to go through the network once: compute an embedding for every image, fit a KNN classifier on the embeddings of the labelled images, and then use the KNN classifier to predict the class of each test image directly.

    You can also refer to the Blog if you aren't familiar with KNN or want to look at code implementations.
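Here is a minimal sketch of the KNN idea, assuming the Siamese network's two branches share one encoder that maps an image to an embedding vector, and that Euclidean distance in that space reflects similarity (as with a contrastive-style loss). If your model instead ends in a learned similarity head, this does not apply directly. The `embed` function is a hypothetical stand-in for that encoder so the example runs without TensorFlow; with a real Keras encoder it would be one batched call over all images. The data is synthetic, and the reference set uses 5 labelled images per class, mirroring the first option above.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_CLASSES, REFS_PER_CLASS, NUM_TEST = 104, 5, 7400


def embed(images: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the shared encoder of the trained Siamese
    network. With a real Keras encoder this would be a single batched call,
    e.g. encoder.predict(images, batch_size=256). Here it just flattens and
    L2-normalises the pixels so the sketch runs without TensorFlow."""
    flat = images.reshape(len(images), -1).astype(np.float64)
    return flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-12)


def knn_predict(ref_emb, ref_labels, query_emb, k=5, chunk=1024):
    """Label each query by majority vote over its k nearest reference embeddings."""
    preds = np.empty(len(query_emb), dtype=ref_labels.dtype)
    ref_sq = np.sum(ref_emb ** 2, axis=1)
    for start in range(0, len(query_emb), chunk):
        q = query_emb[start:start + chunk]
        # Squared Euclidean distances ||q||^2 - 2 q.r + ||r||^2, shape (len(q), n_ref).
        # Working in chunks keeps this distance matrix small.
        d2 = np.sum(q ** 2, axis=1)[:, None] - 2.0 * q @ ref_emb.T + ref_sq[None, :]
        # Indices of the k smallest distances in each row (unordered).
        nearest = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
        for i, idx in enumerate(nearest):
            labels, counts = np.unique(ref_labels[idx], return_counts=True)
            preds[start + i] = labels[np.argmax(counts)]  # ties -> smallest label
    return preds


# Synthetic data: one random 8x8 "prototype" per class plus small noise.
prototypes = rng.normal(size=(NUM_CLASSES, 8, 8))
ref_labels = np.repeat(np.arange(NUM_CLASSES), REFS_PER_CLASS)
ref_images = prototypes[ref_labels] + 0.1 * rng.normal(size=(len(ref_labels), 8, 8))
test_labels = rng.integers(0, NUM_CLASSES, size=NUM_TEST)
test_images = prototypes[test_labels] + 0.1 * rng.normal(size=(NUM_TEST, 8, 8))

# Each image goes through the (stand-in) encoder exactly once.
ref_emb = embed(ref_images)
test_emb = embed(test_images)

preds = knn_predict(ref_emb, ref_labels, test_emb, k=REFS_PER_CLASS)
accuracy = np.mean(preds == test_labels)
print(f"encoder passes: {len(ref_images) + NUM_TEST}, accuracy: {accuracy:.3f}")
```

The point of this design is that the network runs once per image (520 + 7400 = 7920 encoder passes here) instead of once per pair (7400 × 104 = 769,600 pair predictions), and everything after that is cheap array math. Calling predict once over all images, rather than repeatedly in a loop, may also avoid the memory growth described in the question. If you have scikit-learn installed, `KNeighborsClassifier(n_neighbors=5).fit(ref_emb, ref_labels).predict(test_emb)` performs the same KNN step.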