
How can skimage.segmentation.slic segment an image using a non-binary (multi-region) mask?


SLIC can segment an image restricted to a binary mask, as shown in the figure below

from https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_mask_slic.html

But what should I do if I need to compute superpixels separately for several adjacent regions?

In my mask, each color represents one region, and each region needs its own independent superpixel segmentation.


Solution

  • There is currently no way to handle a mask with multiple regions in a single call. For your use case you will have to split each region into a separate binary mask and then call slic once per mask. You can then combine the individual segmentations into one label image by offsetting each segmentation's labels by the largest label used so far, so that labels from different regions never collide.

    Pasted below is a concrete example of this for two separate masked regions (adapted from the existing example you referenced):

    
    import matplotlib.pyplot as plt
    import numpy as np
    
    from skimage import data
    from skimage import color
    from skimage import morphology
    from skimage import segmentation
    
    # Input data
    img = data.immunohistochemistry()
    
    # Compute a mask
    lum = color.rgb2gray(img)
    mask = morphology.remove_small_holes(
        morphology.remove_small_objects(
            lum < 0.7, 500),
        500)
    
    mask1 = morphology.opening(mask, morphology.disk(3))
    # create a second mask as the inverse of the first
    mask2 = ~mask1
    
    segmented = np.zeros(img.shape[:-1], dtype=np.int64)
    max_label = 0
    # replace [mask2, mask1] with a list of any number of non-overlapping binary masks
    for region_mask in [mask2, mask1]:
    
        # maskSLIC result; pixels outside region_mask are not segmented
        m_slic = segmentation.slic(img, n_segments=100, mask=region_mask, start_label=1)
        # copy only this region's labels, offset by the current maximum label,
        # so they cannot collide with labels from earlier regions
        segmented[region_mask] = m_slic[region_mask] + max_label
        # update the maximum label used so far
        max_label = segmented.max()
    
    
    # Display result
    fig, ax_arr = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 10))
    ax1, ax2, ax3, ax4 = ax_arr.ravel()
    
    ax1.imshow(img)
    ax1.set_title('Original image')
    
    ax2.imshow(mask1, cmap='gray')
    ax2.set_title('Mask')
    
    ax3.imshow(segmentation.mark_boundaries(img, m_slic))
    ax3.contour(mask1, colors='red', linewidths=1)
    ax3.set_title('maskSLIC (mask1 only)')
    
    ax4.imshow(segmentation.mark_boundaries(img, segmented))
    ax4.contour(mask1, colors='red', linewidths=1)
    ax4.set_title('maskSLIC (both masks)')
    
    for ax in ax_arr.ravel():
        ax.set_axis_off()
    
    plt.tight_layout()
    plt.show()
    

    The basic approach I am suggesting is in the for loop above. Most of the other code is just generating the data and plots.