python  post-processing  numpy-ndarray

Remove connected components below a threshold in a 3-d array


I am working with a 3-D NumPy array in Python and want to post-process the CNN output of brain tumor segmentation images. The output is a 3-D (208x208x155) NumPy array in which each voxel has the value 0, 1, 2 or 4. To get better results, I want to remove the connected components that fall below a threshold of 1000.

I tried erosion followed by dilation, but the results were not good. Can anyone help me?


Solution

  • Ok, so shrink and grow (erosion and dilation) is, as you realised yourself, not the way to approach this problem. What you need is region labelling, also called connected-component labelling, and SciPy has a function, scipy.ndimage.label, that does this for n-dimensional images.

    I assume that by "a threshold less than 1000" you mean that the sum of the voxel values in a connected component is below 1000.

    Here is an outline of how I would do it.

    import numpy as np
    from scipy.ndimage import label
    
    segmentation_mask = [...]  # This should be your 3D mask.
    
    # Let us create a binary mask.
    # It is 0 everywhere `segmentation_mask` is 0 and 1 everywhere else.
    binary_mask = segmentation_mask.copy()
    binary_mask[binary_mask != 0] = 1
    
    # Now, we perform region labelling. Every connected component gets its
    # own integer label from 1 to num_labels; the background is labelled 0.
    labelled_mask, num_labels = label(binary_mask)
    
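    # Let us now remove all the regions that are too small.
    refined_mask = segmentation_mask.copy()
    minimum_cc_sum = 1000
    # Component labels run from 1 to num_labels (0 is the background), and
    # the loop variable must not shadow the imported `label` function.
    for component in range(1, num_labels + 1):
        component_voxels = labelled_mask == component
        if np.sum(segmentation_mask[component_voxels]) < minimum_cc_sum:
            refined_mask[component_voxels] = 0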
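
If the mask contains many components, the loop above can get slow, because every iteration scans the whole volume. The following is a minimal sketch of a vectorised variant, not a definitive implementation. The function name `remove_small_components` and its parameters `min_size`, `by` and `connectivity` are hypothetical names introduced for illustration. It also covers the other reading of the question, where the threshold is the number of voxels in a component rather than the sum of their values, and it exposes the neighbourhood used for labelling, since `scipy.ndimage.label` by default only joins voxels that share a face.

```python
import numpy as np
from scipy import ndimage


def remove_small_components(mask, min_size=1000, by="sum", connectivity=1):
    """Zero out connected components of `mask` that fall below `min_size`.

    by="sum"   compares the sum of the original voxel values (e.g. 0/1/2/4).
    by="count" compares the number of voxels in the component.
    """
    if by not in ("sum", "count"):
        raise ValueError("by must be 'sum' or 'count'")

    mask = np.asarray(mask)
    binary = mask != 0

    # connectivity=1 joins voxels that share a face (the default behaviour of
    # scipy.ndimage.label); connectivity=mask.ndim also joins voxels that only
    # touch along an edge or at a corner.
    structure = ndimage.generate_binary_structure(mask.ndim, connectivity)
    labelled, num_labels = ndimage.label(binary, structure=structure)

    # One pass over the volume: sizes[k] is the sum (or voxel count) of
    # component k. Index 0 is the background.
    weights = mask.ravel() if by == "sum" else None
    sizes = np.bincount(labelled.ravel(), weights=weights,
                        minlength=num_labels + 1)

    # Look up, for every voxel, whether its component is large enough.
    keep = sizes >= min_size
    keep[0] = False  # background voxels are already 0

    refined = mask.copy()
    refined[~keep[labelled]] = 0
    return refined
```

With the array from the question, a call would look like `refined_mask = remove_small_components(segmentation_mask, min_size=1000, by="sum")`. Passing `by="count"` switches to a voxel-count threshold, and `connectivity=3` treats diagonal neighbours in 3-D as connected.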