Tags: python, python-3.x, tensorflow, object-detection-api

Coordinates of bounding box in TensorFlow


I want the coordinates of the predicted bounding boxes from a TensorFlow object detection model.
I am using the object-detection script from here.
After following some answers on Stack Overflow, I modified the last block of the detection code as follows:

for image_path in TEST_IMAGE_PATHS:
  image = Image.open(image_path)
  # the array based representation of the image will be used later in order to prepare the
  # result image with boxes and labels on it.
  image_np = load_image_into_numpy_array(image)
  # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
  image_np_expanded = np.expand_dims(image_np, axis=0)
  # Actual detection.
  output_dict = run_inference_for_single_image(image_np, detection_graph)
  # Visualization of the results of a detection.
  width, height = image.size
  print(width,height)
  ymin = output_dict['detection_boxes'][5][0]*height
  xmin = output_dict['detection_boxes'][5][1]*width
  ymax = output_dict['detection_boxes'][5][2]*height
  xmax = output_dict['detection_boxes'][5][3]*width
  #print(output_dict['detection_boxes'][0])
  print (xmin,ymin)
  print (xmax,ymax)

However, output_dict['detection_boxes'] always contains 100 tuples.
There are 100 tuples even for images on which the model failed to detect anything.

What I want is the coordinates of all the bounding boxes for a single image.


Solution

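  • The model returns a fixed number of candidate boxes (100 here) for every image, each with a confidence score, so you need to filter them by score. After the expand_dims line, you can add the following code. The filtered_boxes variable will hold the bounding boxes whose prediction scores are above 0.5.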

      # Run the detector; every output has a leading batch dimension of 1.
      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
      # Keep detections with a class id in 1..90 and a score above 0.5.
      indexes = []
      for i in range(classes.size):
          if 1 <= classes[0][i] <= 90 and scores[0][i] > 0.5:
              indexes.append(i)
      # Boxes are normalized [ymin, xmin, ymax, xmax]; scale by height/width for pixels.
      filtered_boxes = boxes[0][indexes, ...]
      filtered_scores = scores[0][indexes, ...]
      filtered_classes = classes[0][indexes, ...]
      # Distinct class ids found in the image (no longer aligned with filtered_boxes).
      filtered_classes = [int(i) for i in set(filtered_classes)]
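
If you would rather keep the output_dict from run_inference_for_single_image as in the question, the same filtering can probably be done with a NumPy boolean mask. This is only a minimal sketch: the helper name boxes_in_pixels and the min_score parameter are made up for illustration, and it assumes output_dict also carries a 'detection_scores' array with one score per box and that the boxes are normalized [ymin, xmin, ymax, xmax], as the question's own code implies. The fake_output dictionary below is stand-in data, not real model output.

```python
import numpy as np


def boxes_in_pixels(output_dict, width, height, min_score=0.5):
    """Hypothetical helper: return an (N, 4) array of
    [xmin, ymin, xmax, ymax] pixel boxes whose score exceeds min_score."""
    boxes = np.asarray(output_dict['detection_boxes'], dtype=np.float64)
    # Assumption: scores are stored under 'detection_scores', one per box.
    scores = np.asarray(output_dict['detection_scores'], dtype=np.float64)

    # Boolean mask selecting only the confident detections.
    keep = scores > min_score

    # Normalized boxes are ordered [ymin, xmin, ymax, xmax].
    ymin, xmin, ymax, xmax = boxes[keep].T

    # Scale x by the image width and y by the image height.
    return np.stack(
        [xmin * width, ymin * height, xmax * width, ymax * height], axis=1)


# Stand-in data: three candidate boxes, only two above the 0.5 threshold.
fake_output = {
    'detection_boxes': np.array([[0.1, 0.2, 0.5, 0.6],
                                 [0.0, 0.0, 1.0, 1.0],
                                 [0.3, 0.3, 0.4, 0.4]]),
    'detection_scores': np.array([0.9, 0.2, 0.7]),
}

pixel_boxes = boxes_in_pixels(fake_output, width=200, height=100)
for xmin, ymin, xmax, ymax in pixel_boxes:
    print((xmin, ymin), (xmax, ymax))
```

In the question's loop you would call it with the width and height taken from image.size. An image with no confident detections simply gives back an empty array with shape (0, 4).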