Tags: python, tensorflow, crop, image-segmentation, object-detection-api

Tensorflow Object Detection API cropping segments of an image


I am using the Tensorflow Object Detection API with models that can detect objects with bounding boxes and masks.

Here is my code:

import numpy as np
import tensorflow as tf
from object_detection.utils import ops as utils_ops

def run_inference_for_single_image_raw(image, graph):
  with graph.as_default():
    with tf.Session() as sess:
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      for key in [
          'num_detections', 'detection_boxes', 'detection_scores',
          'detection_classes', 'detection_masks'
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      if 'detection_masks' in tensor_dict:
        detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
        detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
        real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
        detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
        detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, image.shape[0], image.shape[1])
        detection_masks_reframed = tf.cast(
            tf.greater(detection_masks_reframed, 0.5), tf.uint8)
        tensor_dict['detection_masks'] = tf.expand_dims(
            detection_masks_reframed, 0)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: np.expand_dims(image, 0)})

  return output_dict

If I run the following code:

vis_util.visualize_boxes_and_labels_on_image_array(
      image,
      output_dict['detection_boxes'],
      output_dict['detection_classes'],
      output_dict['detection_scores'],
      category_index,
      instance_masks=output_dict.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)
plt.figure(figsize=(12, 8))
plt.grid(False)
plt.imshow(image)

The result is: [image with bounding boxes and masks]

How can I crop the objects out of the image along the mask outline rather than the bounding box? In this example I want output images that contain only the object (cat/bottle) on a transparent background, perhaps using PIL or OpenCV.


Solution

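  • If output_dict.get('detection_masks') is a NumPy array holding binary masks, you can crop the image by multiplying it with a mask or by using np.where. The array holds one H x W mask per detection, so combine them into a single mask first and add a channel axis so that it broadcasts against a colour image.

    masks = output_dict.get('detection_masks')  # assumed shape (N, H, W); index [0] first if a batch axis is present
    mask = masks.max(axis=0)                    # union of all instance masks, shape (H, W)
    img_cropped = img * mask[..., None]         # zero out everything outside the masks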
    

    This crops all the detected objects at once. To crop objects individually, you can label the connected components of the mask, for example with scikit-image:

    from skimage import measure
    label_mask = measure.label(mask)
    

    This labels every connected component in the binary image and assigns each one a numeric label (by changing the pixel values). Labels start at 1 and run up to the number of components; 0 is the background.

    single_object_mask = (label_mask == 1) #or 2, 3...
    

    This keeps only the pixels that carry the label you chose. You can also use the bounding-box information to crop particular objects.
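To get the transparent background the question asks for, one option is to attach a single object's mask to the image as an alpha channel and then crop to the mask's extent. The sketch below is a minimal illustration, not the API's own method. It assumes the image is a uint8 array of shape (H, W, 3) and the mask is a 2-D array of shape (H, W). The names crop_instance_rgba, demo_image, demo_mask and demo_crop are hypothetical and exist only for this example.

```python
import numpy as np


def crop_instance_rgba(image, mask):
    """Cut one object out of `image` using its binary `mask`.

    image: uint8 array of shape (H, W, 3)
    mask:  array of shape (H, W); non-zero pixels belong to the object
    Returns a uint8 RGBA array cropped to the mask's bounding box, with
    alpha 255 on the object and 0 elsewhere, or None if the mask is empty.
    """
    mask = np.asarray(mask).astype(bool)
    if not mask.any():
        return None

    # Rows and columns that contain at least one object pixel.
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    top, bottom = rows[0], rows[-1] + 1
    left, right = cols[0], cols[-1] + 1

    # Append the mask as an alpha channel, then crop to the tight box.
    alpha = mask.astype(np.uint8) * 255
    rgba = np.dstack([image, alpha])
    return rgba[top:bottom, left:right]


# Tiny synthetic example: a 4x5 white image with a 2x2 object.
demo_image = np.full((4, 5, 3), 255, dtype=np.uint8)
demo_mask = np.zeros((4, 5), dtype=np.uint8)
demo_mask[1:3, 2:4] = 1
demo_crop = crop_instance_rgba(demo_image, demo_mask)
print(demo_crop.shape)
```

With real detections you would call this once per instance mask (each masks[i] from the detection_masks array, or each label_mask == k from the labelling step) and save the result with PIL, for example Image.fromarray(crop).save('object.png'). PNG keeps the alpha channel, while JPEG would discard it.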