Get rid of overlapping bounding boxes across different classes in TensorFlow Object Detection API


I am using the TensorFlow Object Detection API to train my own vehicle detector. When I tested my model with the Object Detection tutorial, I found that a truck is sometimes detected as both a car and a truck, with two overlapping bounding boxes around it. I want to keep only the box with the highest detection score.

I know that the Object Detection API applies non-max suppression to remove overlapping bounding boxes, but only among boxes of the same class, not across different classes. Is there a way to remove overlapping boxes regardless of class? Is there any place in the Object Detection API code that I can change to achieve that?
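Both boxes survive because per-class suppression only compares boxes that share a label. A minimal sketch, using made-up coordinates for a hypothetical "car" box and "truck" box around the same vehicle (in the API's [ymin, xmin, ymax, xmax] normalized order), shows how large their intersection over union (IoU) can be even though it is never checked:

```python
def iou(a, b):
    """IoU of two boxes given as [ymin, xmin, ymax, xmax]."""
    # Corners of the intersection rectangle.
    ymin, xmin = max(a[0], b[0]), max(a[1], b[1])
    ymax, xmax = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

# Hypothetical detections of the same vehicle under two different labels.
car_box = [0.20, 0.30, 0.60, 0.70]
truck_box = [0.22, 0.28, 0.62, 0.72]
print(round(iou(car_box, truck_box), 3))  # 0.826
```

An IoU of about 0.83 would normally trigger suppression, but per-class NMS never puts these two boxes in the same comparison.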


Solution

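  • You can run the API's box_list_ops.non_max_suppression over the detections of all classes at once, so that a box is also suppressed by an overlapping, higher-scoring box of a different class: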

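      import tensorflow as tf
      from object_detection.core import box_list, box_list_ops

      # boxes: [N, 4] array of [ymin, xmin, ymax, xmax]; scores, classes: [N].
      # In the tutorial these are e.g. output_dict['detection_boxes'],
      # output_dict['detection_scores'] and output_dict['detection_classes'].
      corners = tf.constant(boxes, tf.float32)  # BoxList requires float32
      boxesList = box_list.BoxList(corners)
      boxesList.add_field('scores', tf.constant(scores, tf.float32))
      # Extra fields are gathered along with the kept boxes, so labels survive.
      boxesList.add_field('classes', tf.constant(classes, tf.int32))
      iou_thresh = 0.1  # suppress boxes overlapping a better box by more than this
      max_output_size = 100
      # All boxes are compared regardless of class (class-agnostic NMS).
      nms = box_list_ops.non_max_suppression(
          boxesList, iou_thresh, max_output_size)
      sess = tf.Session()  # TF 1.x graph mode
      boxes, scores, classes = sess.run(
          [nms.get(), nms.get_field('scores'), nms.get_field('classes')])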
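To make the idea concrete without TensorFlow, here is a NumPy sketch of greedy, class-agnostic non-max suppression. It is an illustration of the same technique, not the API's implementation; the function name class_agnostic_nms, the example boxes and the label ids are hypothetical.

```python
import numpy as np

def class_agnostic_nms(boxes, scores, classes, iou_thresh=0.1, max_output_size=100):
    """Greedy NMS over all classes at once.

    boxes: [N, 4] of [ymin, xmin, ymax, xmax]; scores, classes: [N].
    Returns the kept boxes, scores and classes, highest score first.
    """
    boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float64)
    classes = np.asarray(classes)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(-scores)  # indices sorted by descending score
    keep = []
    while order.size > 0 and len(keep) < max_output_size:
        best, rest = order[0], order[1:]
        keep.append(best)
        # Intersection of the best remaining box with every other candidate.
        ymin = np.maximum(boxes[best, 0], boxes[rest, 0])
        xmin = np.maximum(boxes[best, 1], boxes[rest, 1])
        ymax = np.minimum(boxes[best, 2], boxes[rest, 2])
        xmax = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.clip(ymax - ymin, 0, None) * np.clip(xmax - xmin, 0, None)
        union = np.maximum(areas[best] + areas[rest] - inter, 1e-12)
        # Drop every candidate that overlaps too much, whatever its class.
        order = rest[inter / union <= iou_thresh]
    keep = np.array(keep, dtype=int)
    return boxes[keep], scores[keep], classes[keep]

# Hypothetical detections: a truck (label 8) also detected as a car (label 3),
# plus a separate, non-overlapping car.
boxes = [[0.20, 0.30, 0.60, 0.70],
         [0.22, 0.28, 0.62, 0.72],
         [0.70, 0.10, 0.90, 0.30]]
scores = [0.85, 0.60, 0.75]
classes = [8, 3, 3]
kept_boxes, kept_scores, kept_classes = class_agnostic_nms(boxes, scores, classes)
print(kept_classes, kept_scores)
```

The lower-scoring "car" box on the truck is suppressed because it overlaps the "truck" box, while the separate car is kept since it does not overlap anything better. A low threshold such as 0.1 is aggressive: it will also merge genuinely different vehicles that overlap slightly, so it may need tuning.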