Search code examples
python k-means

How to save specific color image from the output of Kmeans using python


I am using the code below for color-based segmentation with K-means. In this code, every cluster is extracted and saved into one output image. My requirement is a bit different: I want to save only the blue-colored cluster. Could you please help me save only the blue part of the image?

import numpy as np
import cv2
import pdb
from matplotlib import pyplot as plt
# Read the input image and flatten it to one float32 row per pixel, which is
# the sample layout passed to cv2.kmeans below.
img = cv2.imread('a.png')
Z = img.reshape((-1, 3)).astype(np.float32)

# Termination: stop on either the iteration cap (10) or the epsilon (1.0).
K = 4
stop_flags = cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER
criteria = (stop_flags, 10, 1.0)
_, flat_labels, centers = cv2.kmeans(
    Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
)

# Bring the per-pixel labels back to the image's height x width layout and
# paint every pixel with the (8-bit) color of its cluster center.
labels = flat_labels.reshape(img.shape[:-1])
palette = np.uint8(centers)
reduced = palette[labels]

# First row of the output: original next to the color-quantized image.
result = [np.hstack([img, reduced])]

# One extra row per cluster: only the pixels that belong to that cluster,
# taken from the original (left) and from the quantized image (right).
for cluster_id in range(len(centers)):
    single_channel = cv2.inRange(labels, cluster_id, cluster_id)
    # Replicate the mask across three channels so it can be AND-ed with
    # the color images.
    mask = np.dstack((single_channel, single_channel, single_channel))
    row = np.hstack([
        cv2.bitwise_and(img, mask),
        cv2.bitwise_and(reduced, mask),
    ])
    result.append(row)
    pdb.set_trace()

stacked = np.vstack(result)
cv2.imwrite('watermelon_out.jpg', stacked)

Original Image Original Image

After running this code, I get the result shown below: aa

Expected Result: Expected Result


Solution

  • This should save only the blue cluster. First, find the cluster center that is closest to the blue color; then keep only the pixels belonging to the cluster represented by that center.

    import numpy as np
    import cv2
    import pdb
    from matplotlib import pyplot as plt
    # Read the input image. cv2.imread signals failure by returning None
    # rather than raising, so check explicitly before using the result.
    img = cv2.imread('a.png')
    if img is None:
        raise FileNotFoundError("could not read image 'a.png'")

    # One float32 row per pixel: the sample layout passed to cv2.kmeans.
    Z = np.float32(img.reshape((-1,3)))

    # Stop on either the iteration cap (10) or the epsilon (1.0).
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    K = 4
    _,labels,centers = cv2.kmeans(Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    labels = labels.reshape((img.shape[:-1]))
    reduced = np.uint8(centers)[labels]

    # Reference "blue". cv2.imread yields channels in B, G, R order, so the
    # first component is the blue intensity.
    b = (255, 50 , 0)

    # Squared Euclidean distance from every cluster center to the reference
    # color, using ALL three channels (index 0, 1 and 2). The center with the
    # smallest distance is taken as the blue cluster; argmin returns the first
    # minimum, so ties resolve to the lowest cluster index.
    distances = np.sum((centers - np.float32(b)) ** 2, axis=1)
    blue_center = int(np.argmin(distances))
    blue_dis = distances[blue_center]

    # First row of the output: original next to the color-quantized image.
    result = [np.hstack([img, reduced])]

    # Second row: only the pixels assigned to the blue cluster, taken from the
    # original (left) and from the quantized image (right).
    mask = cv2.inRange(labels, blue_center, blue_center)
    mask = np.dstack([mask]*3) # Make it 3 channel
    ex_img = cv2.bitwise_and(img, mask)
    ex_reduced = cv2.bitwise_and(reduced, mask)
    result.append(np.hstack([ex_img, ex_reduced]))

    cv2.imwrite('watermelon_out.jpg', np.vstack(result))