Search code examples
Tags: python, python-3.x, matplotlib, scikit-image

Color map not changing with imshow()


I am applying KMeans to an image, but when I pass `cmap` to `imshow()` to change the colors, nothing changes. How can I apply a colormap to the segmented image?

# Load the image and scale pixel values into [0, 1]
# (assumes an 8-bit image — TODO confirm for other bit depths).
im = io.imread("image.jpg") / 255
x, y, z = im.shape
# Flatten to one row per pixel and one column per channel, as KMeans expects.
im_2D = im.reshape(x*y, z)
kmeans = KMeans(n_clusters=3, random_state=0)
kmeans.fit(im_2D)
# Replace each pixel by its cluster centre. The reshape back to (x, y, z)
# yields a 3-D colour image, not a 2-D array of scalar values.
im_clustered = kmeans.cluster_centers_[kmeans.labels_].reshape(x, y, z)

fig, ax = plt.subplots(1, 2)
ax[0].imshow(im)
ax[0].set_title("Original")
# NOTE: cmap has no visible effect here because im_clustered holds colour
# (RGB) values; a colormap only applies to 2-D scalar data such as
# kmeans.labels_.reshape(x, y).
ax[1].imshow(im_clustered, cmap="jet")
ax[1].set_title("Segmented using k=3")
plt.show()

EDIT:

This is the output using the code above: (image not included)

This is what I would like the output to be when using the jet cmap: (image not included)


Solution

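  • You can use ax[1].imshow(kmeans.labels_.reshape(x, y), cmap='jet'). The current im_clustered is a 3-D array of RGB values, and imshow() ignores cmap for RGB(A) input. To apply a colormap you need a 2-D array of scalar values, such as the per-pixel cluster labels.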

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cbook as cbook
    from sklearn.cluster import KMeans
    
    # Load a sample image bundled with matplotlib; plt.imread decodes PNG
    # files to floating-point pixel values.
    with cbook.get_sample_data('ada.png') as image_file:
        im = plt.imread(image_file)
    x, y, z = im.shape
    # Flatten to one row per pixel and one column per channel, for KMeans.
    im_2D = im.reshape(x * y, z)
    kmeans = KMeans(n_clusters=3, random_state=0)
    kmeans.fit(im_2D)
    # Clip the centre colours into [0, 1] for display. Use a local copy so
    # the fitted model's cluster_centers_ (used by predict()) stays intact.
    centers = np.clip(kmeans.cluster_centers_, 0, 1)
    # 3-D colour image: every pixel replaced by its cluster centre's colour.
    im_clustered = centers[kmeans.labels_].reshape(x, y, z)
    # 2-D array of scalar cluster ids: the kind of input a colormap applies to.
    labels_2D = kmeans.labels_.reshape(x, y)

    fig, ax = plt.subplots(1, 3, figsize=(10, 4))
    for ax_i in ax:
        ax_i.axis('off')
    ax[0].imshow(im)
    ax[0].set_title("Original")
    # cmap is ignored for this panel because im_clustered holds colour values.
    ax[1].imshow(im_clustered, cmap="jet")
    ax[1].set_title("Segmented using k=3")
    ax[2].imshow(labels_2D, cmap="jet")
    ax[2].set_title("Segmented, k=3, jet cmap")
    plt.show()
    

    segmented image