Tags: python, image-segmentation, semantic-segmentation, yolov8

How to create a binary mask from a YOLOv8 segmentation result


I want to segment an image using YOLOv8 and then create one mask covering all objects of a specific class in the image.

I have developed this code:

import cv2
import numpy as np
from ultralytics import YOLO

img = cv2.imread('images/bus.jpg')
height, width = img.shape[:2]
model = YOLO('yolov8m-seg.pt')
results = model.predict(source=img.copy(), save=False, save_txt=False)
class_ids = np.array(results[0].boxes.cls.cpu(), dtype="int")
for i in range(len(class_ids)):
    if class_ids[i] == 0:
        empty_image = np.zeros((height, width, 3), dtype=np.uint8)
        res_plotted = results[0][i].plot(boxes=False, img=empty_image)

In the above code, res_plotted is the mask for one object, as an RGB image. I want to combine all of these images into a single mask covering every object of class 0 (a pedestrian in this example).
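For reference, combining per-object masks like res_plotted comes down to a pixel-wise logical OR. The sketch below assumes each mask is an H x W array, or an H x W x 3 rendering on a black background where any non-black pixel is foreground; merge_masks is a hypothetical helper, not part of ultralytics.

```python
import numpy as np


def merge_masks(masks, shape):
    """Pixel-wise OR of per-object masks into one uint8 mask with values 0 or 255.

    Hypothetical helper: each mask may be H x W or H x W x 3 (e.g. an RGB
    rendering on a black background); any non-zero pixel counts as foreground.
    """
    merged = np.zeros(shape, dtype=bool)
    for m in masks:
        # Collapse colour channels: a pixel is foreground if any channel is non-zero
        fg = m.any(axis=-1) if m.ndim == 3 else m.astype(bool)
        merged |= fg
    # Scale True/False to 255/0 so the result can be saved as an 8-bit image
    return merged.astype(np.uint8) * 255


# Two hypothetical 4x4 RGB renderings, each containing one object pixel
a = np.zeros((4, 4, 3), dtype=np.uint8)
a[0, 0] = (255, 0, 0)
b = np.zeros((4, 4, 3), dtype=np.uint8)
b[3, 3, 1] = 200
out = merge_masks([a, b], (4, 4))
print(np.count_nonzero(out))  # → 2
```

In the loop above, one could append each res_plotted to a list and call merge_masks(that_list, (height, width)) once the loop finishes.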

My questions:

  1. How can I complete this code?
  2. Is there a better way to achieve this without a loop?
  3. Is there a utility function in the YOLOv8 library for this?

Solution

  • Select the person masks using the class of each detection. The mask tensor has shape [N, H, W], one mask per detection, so keeping only the detections whose class is 0 gives a [K, H, W] tensor with one slice per person. Then apply torch.any over the first dimension to flatten these K masks into a single-channel [H, W] mask.

    import cv2
    from ultralytics import YOLO
    import numpy as np
    import torch
    
    
    img = cv2.imread('ultralytics/assets/bus.jpg')
    model = YOLO('yolov8m-seg.pt')
    results = model.predict(source=img.copy(), save=True, save_txt=False, stream=True)
    for result in results:
        # result.masks is None when nothing was detected in this image
        if result.masks is None:
            continue
        # get array results: masks has shape [N, H, W]
        masks = result.masks.data
        # extract classes, one per detection (shape [N])
        clss = result.boxes.cls
        # boolean index of results where class is 0 (person in COCO)
        people_indices = clss == 0
        # use these indices to extract the relevant masks: [K, H, W]
        people_masks = masks[people_indices]
        # OR the person masks together and scale to 0/255 for visualizing results
        people_mask = torch.any(people_masks.bool(), dim=0).to(torch.uint8) * 255
        # save to file
        cv2.imwrite(str(model.predictor.save_dir / 'merged_segs.jpg'), people_mask.cpu().numpy())
    

    (Figures: input image with bounding boxes and segmentations; merged output mask.)

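    The selection and merging run as torch operations on whatever device the model uses (the GPU, if one is available), with no Python loop over detections; only the final merged mask is copied to the CPU for saving.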
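The select-then-any step can be isolated into a small function and checked on synthetic tensors, without loading a model. This is a sketch: merge_class_masks is a hypothetical helper, and it assumes masks has shape [N, H, W] (as result.masks.data does) and classes has shape [N] (as result.boxes.cls does), both on the same device.

```python
import torch


def merge_class_masks(masks, classes, class_id):
    """Merge all masks whose class equals class_id into one uint8 mask (0/255).

    Hypothetical helper.
    masks:   float or bool tensor of shape [N, H, W]
    classes: tensor of shape [N] holding one class index per mask
    """
    # Boolean indexing keeps only the masks of the requested class: [K, H, W]
    selected = masks[classes == class_id]
    # torch.any over dim 0 ORs the K masks; with K == 0 it returns all False
    merged = torch.any(selected.bool(), dim=0)
    return merged.to(torch.uint8) * 255


# Synthetic example: two masks of class 0 and one of class 2
masks = torch.zeros(3, 4, 4)
masks[0, 0, 0] = 1.0
masks[1, 3, 3] = 1.0
masks[2, 1, 1] = 1.0
classes = torch.tensor([0.0, 0.0, 2.0])
out = merge_class_masks(masks, classes, 0)
```

With the answer's variables this would be merge_class_masks(result.masks.data, result.boxes.cls, 0).cpu().numpy(), after checking that result.masks is not None. By default masks.data is at the inference resolution rather than the original image size; the predict argument retina_masks=True is documented to return masks at the original size, which is worth confirming against the installed ultralytics version.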