Tags: python, opencv, image-processing, scikit-learn, cluster-analysis

How to Cluster Parts of a Mask in an Image Using Python?


I need to split a mask wherever its width changes abruptly, so that a narrow section gets separated from a wide one. For example, if I draw a mask over a cat, I want the wide part (the body) to be one mask and the narrow part (the tail) to be another.

Currently, I have a continuous mask that includes both the cat's body and its tail. I want to separate this into two distinct masks. How can I achieve this using Python?

[Image: original mask]

[Image: desired mask]

I looked into the methods described here, which focus on polygon partitioning and on splitting contours into multiple triangles. However, this approach does not suit my needs: I want to split the mask based on the size and shape of its parts, not into many triangular pieces.


Solution

  • You can use convexity defects to find the points between which to "cut" the mask.

    This approach is used here, for example.

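    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    
    def split_mask(mask, min_area=20, cut_thickness=20):
        """Cut a binary mask at its deepest concavities (convexity defects)."""
        _, thresh = cv2.threshold(mask, 120, 255, cv2.THRESH_BINARY)
        # [-2] selects the contour list under both OpenCV 3.x (3 return values)
        # and OpenCV 4.x (2 return values).
        contours = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
    
        for contour in contours:
            if cv2.contourArea(contour) <= min_area:
                continue  # ignore small specks of noise
            hull = cv2.convexHull(contour, returnPoints=False)
            try:
                defects = cv2.convexityDefects(contour, hull)
            except cv2.error:
                continue  # the call can fail on some irregular contours
            if defects is None:
                continue  # the contour is already convex: nothing to cut
    
            # Each defect row is (start, end, farthest_point, depth); depth is a
            # fixed-point value with 8 fractional bits, hence the division by 256.
            depths = defects[:, 0, 3] / 256.0
            max_depth = depths.max()
            # Keep defects that are deep both in absolute terms and relative to
            # the deepest one: these are the "pinch" points of the mask.
            points = [f for (_, _, f, _), d in zip(defects[:, 0], depths)
                      if d > 1.0 and d / max_depth > 0.2]
    
            # With at least two pinch points, erase the mask along a thick line
            # from each point to its nearest neighbouring pinch point.
            if len(points) >= 2:
                for i, f1 in enumerate(points):
                    # cv2.line expects plain Python ints for point coordinates.
                    p1 = tuple(int(v) for v in contour[f1][0])
                    nearest = min(
                        (tuple(int(v) for v in contour[f2][0])
                         for j, f2 in enumerate(points) if i != j),
                        key=lambda p2: (p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
                    cv2.line(thresh, p1, nearest, 0, cut_thickness)
        return thresh
    
    if __name__ == "__main__":
        mask = cv2.imread("<path-to-your-image>", cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError("Could not read the mask image")
        mask_split = split_mask(mask)
        plt.imshow(mask_split, cmap="gray")
        plt.show()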
    

    This yields the following result on your image: [Image: mask after cutting]