python tensorflow object-detection object-detection-api

Improving a pre-trained tensorflow object detection model


I want to use TensorFlow to detect cars on an embedded system. I tried ssd_mobilenet_v2 and it did pretty well, except for some specific car types which are not very common; I think that is why the model does not recognize them. I have a dataset of these cases and I want to improve the model by fine-tuning it. I should also note that I need a .tflite file, because I'm using tflite_runtime in Python. I followed the instructions at https://github.com/EdjeElectronics/TensorFlow-Object-Detection-API-Tutorial-Train-Multiple-Objects-Windows-10 and was able to train the model to a reasonable loss value. I then used export_tflite_ssd_graph.py from the Object Detection API to build an inference graph from the trained model, and afterwards used the toco tool to build a .tflite file from it.

But here is the problem: after all that, not only did the model not improve, it now does not detect any cars at all. I am confused and do not know what the problem is. I searched a lot and did not find any tutorial covering what I need to do. The ones I found just add a new object to a model and then export it, which I tried and managed to do successfully. I also tried building a .tflite file directly from a model in the TensorFlow detection model zoo, without any training, and it worked fine. So I think the problem has something to do with the training process. Maybe I am missing something there.

Another thing I could not find in the documentation is whether it is possible to "add" a class to the existing classes of an object detection model. For example, assuming MobileNet SSD v2 detects 90 different object classes, I would like to add one more so that the model detects 91 classes instead of 90. As far as I understand, and from what I tested, after transfer learning with the Object Detection API the model only detects the objects in my dataset, and the old classes are gone. So how can I do what I described?


Solution

  • I found out that there is no way to 'add' a class to the previously trained classes, but by providing a small amount of data for that class you can have your model detect it. The reason is that the last layer of the model changes when transfer learning is applied. In my case I labeled around 3k frames containing about 12k objects, because my frames were complicated. For simpler tasks, judging by the tutorials I saw, 200-300 annotated images would be enough.

    As for the model not detecting anything, that had to do with the convert command I used: I should have used tflite_convert instead of toco. I explained more here.
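To make the two points above more concrete, here is a minimal sketch in Python. It is not the exact setup used in this answer, whose conversion command is not shown here. The first helper writes a label map listing every class the fine-tuned model should detect (the class names are hypothetical). The second builds a tflite_convert command for a float SSD graph produced by export_tflite_ssd_graph.py; the flags and tensor names follow the Object Detection API's TensorFlow Lite guide for TF 1.x, and the 300x300 input size and file names are assumptions that need adjusting to your own pipeline.

```python
import shlex


def build_label_map(class_names):
    """Return label map text in the Object Detection API's pbtxt format.

    IDs start at 1 because 0 is reserved for the background class.
    """
    items = []
    for class_id, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return "\n".join(items)


def build_tflite_convert_command(graph_def_file, output_file, input_size=300):
    """Return a tflite_convert argument list for a float SSD model.

    Assumes the graph came from export_tflite_ssd_graph.py, which names the
    input 'normalized_input_image_tensor' and exposes four post-processing
    outputs (boxes, classes, scores, number of detections).
    """
    output_arrays = ",".join([
        "TFLite_Detection_PostProcess",
        "TFLite_Detection_PostProcess:1",
        "TFLite_Detection_PostProcess:2",
        "TFLite_Detection_PostProcess:3",
    ])
    return [
        "tflite_convert",
        "--graph_def_file=" + graph_def_file,
        "--output_file=" + output_file,
        "--input_shapes=1,%d,%d,3" % (input_size, input_size),
        "--input_arrays=normalized_input_image_tensor",
        "--output_arrays=" + output_arrays,
        "--inference_type=FLOAT",
        # The detection post-processing step is a custom TFLite op.
        "--allow_custom_ops",
    ]


if __name__ == "__main__":
    # Hypothetical class list: every class you still want after fine-tuning
    # must appear here AND be annotated in the training data.
    print(build_label_map(["car", "truck"]))

    # Hypothetical file names; print the command instead of running it.
    command = build_tflite_convert_command("tflite_graph.pb", "detect.tflite")
    print(" ".join(shlex.quote(arg) for arg in command))
```

The number of entries in the label map has to match num_classes in pipeline.config, which is also why the original 90 classes disappear after fine-tuning on a dataset that only contains cars. The printed command can be run in a shell, or passed to subprocess.run as a list, once the file paths point at your exported graph.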