Tags: scikit-learn, k-means, image-segmentation, unsupervised-learning, semantic-segmentation

KMeans clustering with label data


I have an RGB image of shape (587, 987, 3), i.e. (height, width, num_channels).

I also have label data (pixel locations) for each of 7 classes.

I want to apply the KMeans clustering algorithm to segment the given image into 7 classes. While clustering, I want to make use of the label data, i.e., the pixel locations.

How can I utilize label data?

What I have tried so far is as follows.

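import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

img = np.random.randint(low=1, high=99, size=(587, 987, 3))

im = img.reshape(img.shape[0] * img.shape[1], img.shape[2])
im = StandardScaler().fit_transform(im)

# n_jobs dropped: deprecated in scikit-learn 0.23 and removed in 1.0
clusters = KMeans(n_clusters=7, n_init=100, max_iter=100).fit(im)
kmeans_labels = clusters.labels_.reshape(img.shape[0], img.shape[1])

plt.imshow(kmeans_labels)
plt.show()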
   

I'm looking to propagate some annotations to the remaining segments (superpixels).


Solution

  • As clarified in the comments on the question, you could treat the clusters as superpixels and propagate labels from a few labeled samples to the remaining data, using a semi-supervised classifier [1].

    Creating an image to run the example:

    import numpy as np
    from skimage.data import binary_blobs
    import cv2 
    from pyift.shortestpath import seed_competition
    from scipy import sparse, spatial
    import matplotlib.pyplot as plt 
    
    # creating noisy image
    size = 256 
    image = np.empty((size, size, 3)) 
    image[:, :, 0] = binary_blobs(size, seed=0)
    image[:, :, 1] = binary_blobs(size, seed=0)
    image[:, :, 2] = binary_blobs(size, seed=1)
    image += np.random.randn(*image.shape) / 10
    image -= image.min()
    image /= image.max()
    
    plt.axis(False)
    plt.imshow(image)
    plt.show()
    

    [Figure: noisy image]

    Computing superpixels:

    def grid_seeds(image, rows = 15, cols = 15):
        seeds = np.zeros(image.shape[:2], dtype=int)  # np.int was removed in NumPy 1.24
        v_step, h_step = image.shape[0] // rows, image.shape[1] // cols
        count = 1
        for i in range(rows):
            y = v_step // 2 + i * v_step
            for j in range(cols):
                x = h_step // 2 + j * h_step
                seeds[y, x] = count
                count += 1
        return seeds
                                                                                                     
    seeds = grid_seeds(image)
    _, _, _, superpixels = seed_competition(seeds, image=image)
    superpixels -= 1  # shifting labels to zero
    
    contours, _ = cv2.findContours(superpixels, cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE)
    im_w_contours = image.copy()
    cv2.drawContours(im_w_contours, contours, -1, (1, 0, 0))  # red; image values are in [0, 1]
    
    plt.axis(False)
    plt.imshow(im_w_contours)
    plt.show()
    

    [Figure: superpixels]

    Propagating labels from 4 arbitrary nodes, one for each class (color), and coloring the resulting labels with the expected colors:

    def create_graph(image, labels):
        n_nodes = labels.max() + 1
        h, w, d = image.shape
        avg = np.zeros((n_nodes, d))
        for i in range(h):
            for j in range(w):
                avg[labels[i, j]] += image[i, j]
        avg /= np.bincount(labels.ravel())[:, np.newaxis]  # sum -> mean: divide by each superpixel's pixel count
        graph = spatial.distance_matrix(avg, avg)
        return sparse.csr_matrix(graph)
    
    graph = create_graph(image, superpixels)
    
    graph_seeds = np.zeros(graph.shape[0], dtype=int)  # 0 = unlabeled node
    graph_seeds[1] = 1   # blue training sample
    graph_seeds[3] = 2   # yellow training sample
    graph_seeds[13] = 3  # white training sample
    graph_seeds[14] = 4  # black training sample
    
    label_colors = {1: (0, 0, 255),
                    2: (255, 255, 0),
                    3: (255, 255, 255),
                    4: (0, 0, 0)}
    
    _, _, _, labels = seed_competition(graph_seeds, graph=graph)
    
    result = np.empty(image.shape, dtype=np.uint8)  # uint8 so the 0-255 colors display as intended
    for i, lb in enumerate(labels):
        result[superpixels == i] = label_colors[lb]
    
    plt.axis(False)
    plt.imshow(result)
    plt.show()
    

    [Figure: result]
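
To make the propagation step less of a black box, here is a rough pure-Python sketch of seed competition on a small dense graph. It assumes the cost of a path is its largest arc weight, which is a common choice for this family of algorithms; I have not checked that it matches what pyift does internally. The name seed_competition_sketch and the toy graph are made up for illustration.

```python
import heapq

def seed_competition_sketch(weights, seeds):
    """Propagate seed labels over a dense weighted graph.

    weights: square list of lists with non-negative arc weights.
    seeds:   list with a positive label for seed nodes and 0 elsewhere.
    Assumption: the cost of a path is its largest arc weight.
    """
    n = len(seeds)
    cost = [float("inf")] * n
    label = [0] * n
    done = [False] * n
    heap = []
    # every seed starts with cost 0 and its own label
    for node, lb in enumerate(seeds):
        if lb > 0:
            cost[node] = 0.0
            label[node] = lb
            heapq.heappush(heap, (0.0, node))
    while heap:
        c, u = heapq.heappop(heap)
        if done[u]:
            continue  # stale heap entry, node already finalized
        done[u] = True
        for v in range(n):
            if v == u or done[v]:
                continue
            new_cost = max(c, weights[u][v])
            if new_cost < cost[v]:
                # u offers v a cheaper path, so v inherits u's label
                cost[v] = new_cost
                label[v] = label[u]
                heapq.heappush(heap, (new_cost, v))
    return label, cost

# toy graph: nodes 0 and 1 are similar, nodes 2 and 3 are similar
toy_weights = [
    [0.0, 0.1, 0.9, 0.8],
    [0.1, 0.0, 0.7, 0.9],
    [0.9, 0.7, 0.0, 0.2],
    [0.8, 0.9, 0.2, 0.0],
]
toy_seeds = [1, 0, 0, 2]  # node 0 is a class-1 seed, node 3 a class-2 seed
toy_labels, toy_costs = seed_competition_sketch(toy_weights, toy_seeds)
```

Each unlabeled node ends up with the label of the seed that reaches it most cheaply, which is why the choice of arc weights matters so much.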

    For this example, I used the distance between the average colors of two superpixels as the weight of the arc connecting them. In a real problem, however, a more elaborate feature vector will likely be necessary.
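
As one possible step toward a richer descriptor, the sketch below stacks the per-channel mean and standard deviation of each superpixel and builds the pairwise Euclidean distance matrix from that. This is only an illustration: the helper names are made up, the choice of mean plus standard deviation is an assumption rather than a recommendation from the references, and it assumes every label from 0 to the maximum occurs at least once.

```python
import numpy as np

def superpixel_features(image, labels):
    """Return one row per superpixel: per-channel mean, then per-channel std."""
    flat_labels = labels.ravel()
    pixels = image.reshape(-1, image.shape[-1]).astype(float)
    n_nodes = flat_labels.max() + 1
    counts = np.bincount(flat_labels, minlength=n_nodes).astype(float)
    sums = np.zeros((n_nodes, pixels.shape[1]))
    sq_sums = np.zeros_like(sums)
    # unbuffered accumulation, so repeated labels are all added up
    np.add.at(sums, flat_labels, pixels)
    np.add.at(sq_sums, flat_labels, pixels ** 2)
    mean = sums / counts[:, None]
    # clamp tiny negative values caused by floating point rounding
    var = np.maximum(sq_sums / counts[:, None] - mean ** 2, 0.0)
    return np.hstack([mean, np.sqrt(var)])

def pairwise_distances(features):
    """Dense Euclidean distance matrix between all feature rows."""
    diff = features[:, None, :] - features[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

# toy 2x2 image split into two superpixels (left and right column)
toy_image = np.array([[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
                      [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]])
toy_superpixels = np.array([[0, 1],
                            [0, 1]])
toy_features = superpixel_features(toy_image, toy_superpixels)
toy_graph = pairwise_distances(toy_features)
```

The resulting matrix can be wrapped in sparse.csr_matrix in the same way as the output of create_graph above.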

    Also, the labeled data here is a subset of the image superpixels, but this is not strictly necessary: you can add artificial nodes when modeling your graph, especially as seed nodes.
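
A minimal sketch of the artificial-node idea, assuming you have one prototype feature vector per class (for example, the mean color of the annotated pixels of that class). The function name add_artificial_seeds and the prototype values are hypothetical.

```python
import numpy as np

def add_artificial_seeds(features, prototypes):
    """Append one artificial node per class prototype and seed only those nodes.

    features:   (n_nodes, d) array, one row per superpixel.
    prototypes: dict mapping a positive class label to a length-d feature vector.
    """
    class_labels = sorted(prototypes)
    extra = np.array([prototypes[lb] for lb in class_labels], dtype=float)
    all_features = np.vstack([features, extra])
    # superpixels stay unlabeled (0); only the appended nodes carry labels
    seeds = np.zeros(len(all_features), dtype=int)
    seeds[len(features):] = class_labels
    diff = all_features[:, None, :] - all_features[None, :, :]
    weights = np.sqrt((diff ** 2).sum(axis=-1))
    return weights, seeds

toy_features = np.array([[0.1, 0.1, 0.9],
                         [0.9, 0.9, 0.1],
                         [0.0, 0.2, 1.0]])
toy_prototypes = {1: (0.0, 0.0, 1.0),   # hypothetical "blue" prototype
                  2: (1.0, 1.0, 0.0)}   # hypothetical "yellow" prototype
weights, seeds = add_artificial_seeds(toy_features, toy_prototypes)
```

The weights and seeds can then be passed to the propagation step like the graph and graph_seeds above; the labels of the trailing artificial nodes are simply dropped afterwards.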

    This approach is commonly used in remote sensing; this article might be relevant [2].

    [1] Amorim, W. P., Falcão, A. X., Papa, J. P., & Carvalho, M. H. (2016). Improving semi-supervised learning through optimum connectivity. Pattern Recognition, 60, 72-85.

    [2] Vargas, John E., et al. "Superpixel-based interactive classification of very high resolution images." 2014 27th SIBGRAPI Conference on Graphics, Patterns, and Images. IEEE, 2014.