
Python: how could this image be properly segmented?


I would like to segment (isolate) the rod-like structures shown in this image:

[Image: the original image containing the rod-like structures]

The best I've managed so far is this:

# Imports the libraries.
from skimage import io, filters
import matplotlib.pyplot as plt
import numpy as np

# Imports the image as a numpy array.
img = io.imread('C:/Users/lopez/Desktop/Test electron/test.tif')

# Thresholds the image using a local threshold.
thresh = filters.threshold_local(img,301,offset=0)
binary_local = img > thresh # Thresholds the image
binary_local = np.invert(binary_local) # inverts the thresholded image (True becomes False and vice versa).

# Shows the image.
plt.figure(figsize=(10,10))
plt.imshow(binary_local,cmap='Greys')
plt.axis('off')
plt.show()

which produces this result:

[Image: the local-threshold result]

However, as you can see from the segmented image, I haven't managed to isolate the rods. What should be a black background is instead filled with interconnected structures. Is there a way to neatly isolate the rod-like structures from all other elements in the image?
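One likely contributor to the noisy background is `offset=0`. `threshold_local` compares each pixel against the weighted mean of its neighbourhood minus `offset`, so in a featureless region roughly half of the pixels land above the threshold and show up as structure. The sketch below illustrates this on synthetic noise, which is an assumption standing in for an empty background region; the block size of 51 and the offset of -0.05 are arbitrary illustration values, not settings tuned for the real image.

```python
import numpy as np
from skimage.filters import threshold_local

# Hypothetical stand-in for an empty background region: pure Gaussian noise.
rng = np.random.default_rng(0)
noise = rng.normal(loc=0.5, scale=0.05, size=(256, 256))


def foreground_fraction(image, block_size, offset):
    """Fraction of pixels with image > local threshold (as in `img > thresh`)."""
    thresh = threshold_local(image, block_size, offset=offset)
    return np.mean(image > thresh)


# offset=0: the threshold is just the local weighted mean, so about half of
# the pixels in a featureless region end up above it.
frac_zero = foreground_fraction(noise, 51, offset=0)

# threshold_local subtracts `offset` from the local mean, so a negative offset
# raises the threshold (here by one noise standard deviation) and far fewer
# noise pixels pass.
frac_raised = foreground_fraction(noise, 51, offset=-0.05)

print(f"offset=0:     {frac_zero:.2f}")
print(f"offset=-0.05: {frac_raised:.2f}")
```

On real data the offset would have to be tuned against the actual noise level, and raising the threshold also shrinks faint rods, so this is a trade-off rather than a complete fix.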

The original image can be downloaded from this website:

https://dropoff.nbi.ac.uk/pickup.php

Claim ID: qMNrDHnfEn4nPwB8

Claim Passcode: UkwcYoYfXUfeDto8


Solution

  • Here is my attempt using a Meijering filter. The Meijering filter relies on symmetry when it looks for tubular structures, so regions where rods overlap (breaking the symmetry of the tubular shape) are not recovered very well, as can be seen in the overlay below.

    There is also some random junk that I had trouble getting rid of digitally, but maybe you can clean your prep a bit more before imaging.

    [Image: overlay of the segmented rods on the raw image]

    #!/usr/bin/env python
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    from skimage.io import imread
    from skimage.transform import rescale
    from skimage.restoration import denoise_nl_means
    from skimage.filters import meijering
    from skimage.measure import label
    
    
    def remove_small_objects(binary_mask, size_threshold):
        label_image = label(binary_mask)
        object_sizes = np.bincount(label_image.ravel())
        labels2keep, = np.where(object_sizes > size_threshold)
        labels2keep = labels2keep[labels2keep != 0]  # drop label 0, which is the background
        clean = np.isin(label_image, labels2keep)  # np.in1d is deprecated in recent NumPy
        return clean
    
    
    if __name__ == '__main__':
    
        # convert to float first: in-place division fails on integer TIFF data
        raw = imread('test.tif').astype(float)
        raw -= raw.min()
        raw /= raw.max()
    
        # running everything on the full-size image took too long for my patience,
        # so downsample by a factor of 4 first
        raw = rescale(raw, 0.25, anti_aliasing=True)
    
        # smooth image while preserving edges
        smoothed = denoise_nl_means(raw, h=0.05, fast_mode=True)
    
        # filter for tubular shapes
        sigmas = range(1, 5)
        filtered = meijering(smoothed, sigmas=sigmas, black_ridges=False)
        # Meijering filter always evaluates to high values at the image frame;
        # we hence set the filtered image to zero at those locations
        frame = np.ones_like(filtered, dtype=bool)  # np.bool was removed in NumPy 1.24
        d = 2 * np.max(sigmas) + 1 # this is the theoretical minimum ...
        d += 2 # ... but doesn't seem to be enough so we increase d
        frame[d:-d, d:-d] = False
        filtered[frame] = np.min(filtered)
    
        thresholded = filtered > np.percentile(filtered, 80)
        cleaned = remove_small_objects(thresholded, 200)
    
        overlay = raw.copy()
        overlay[~cleaned] *= 2 / 3  # darken everything that is not a rod
    
        fig, axes = plt.subplots(2, 3, sharex=True, sharey=True)
        axes = axes.ravel()
        axes[0].imshow(raw, cmap='gray')
        axes[1].imshow(smoothed, cmap='gray')
        axes[2].imshow(filtered, cmap='gray')
        axes[3].imshow(thresholded, cmap='gray')
        axes[4].imshow(cleaned, cmap='gray')
        axes[5].imshow(overlay, cmap='gray')
    
        for ax in axes:
            ax.axis('off')
    
        fig, ax = plt.subplots()
        ax.imshow(overlay, cmap='gray')
        ax.axis('off')
        plt.show()
    

    [Image: output of the script]

    If this code makes it into a paper, I want an acknowledgement and a copy of the paper. ;-)
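To see what `meijering(..., black_ridges=False)` responds to, here is a minimal sketch on a synthetic image. The image (a single bright, 3-pixel-thick rod on a dark background) is hypothetical and only stands in for one isolated rod in the micrograph; `sigmas=range(1, 5)` matches the answer, everything else is made up for illustration. The filter output should be high along the rod and zero far away from it.

```python
import numpy as np
from skimage.filters import meijering

# Hypothetical test image: one bright horizontal rod, kept away from the
# image border so edge effects do not interfere.
img = np.zeros((64, 64))
img[31:34, 8:56] = 1.0

# black_ridges=False: look for bright ridges on a dark background, as in the answer.
ridges = meijering(img, sigmas=range(1, 5), black_ridges=False)

on_rod = ridges[32, 32]       # centre of the rod
background = ridges[10, 32]   # far from the rod (beyond the Gaussian support)
print(f"on rod: {on_rod:.3f}, background: {background:.3f}")
```

Running the same check on two crossing rods is a quick way to explore how overlaps affect the response before trying the full pipeline.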
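For the leftover debris, one option the answer did not try is to filter the connected components of the cleaned mask by shape as well as size. The sketch below keeps only components whose fitted ellipse is strongly elongated, using `regionprops` eccentricity (0 for a circle, approaching 1 for a line). The helper name `keep_elongated` and the 0.95 cut-off are hypothetical illustration choices, not values tested on the original image.

```python
import numpy as np
from skimage.measure import label, regionprops


def keep_elongated(binary_mask, min_eccentricity=0.95):
    """Keep only connected components whose fitted ellipse is elongated.

    Hypothetical post-processing step; the 0.95 default is an illustrative
    guess and would need tuning on real data.
    """
    labels = label(binary_mask)
    keep = np.zeros_like(binary_mask, dtype=bool)
    for region in regionprops(labels):
        if region.eccentricity >= min_eccentricity:
            keep[labels == region.label] = True
    return keep


# Toy mask: one thin rod (3 x 30 px) and one roundish blob (10 x 10 px).
mask = np.zeros((40, 40), dtype=bool)
mask[5:8, 5:35] = True      # rod
mask[20:30, 10:20] = True   # blob

rods_only = keep_elongated(mask)
print(rods_only[6, 20], rods_only[25, 15])  # True False
```

In the answer's script this would be applied to `cleaned`, for example `cleaned = keep_elongated(cleaned)`. Overlapping rods merge into components that are less elongated, so they may be dropped too, which is the same weakness the Meijering step has.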