python, matplotlib, k-means

How to label clusters after applying k-means clustering to a dataset?


I have a dataset in .csv format that looks like this:

x,y,z, label
2,1,3, A
5,3,1, B
6,2,2, C
9,5,3, B
2,3,4, A
4,1,4, A

I would like to apply k-means clustering to the dataset above, which is 3-dimensional (x, y, z). After that, I would like to visualize the clusters in 3D, with a label for each cluster shown in the diagram. Please let me know if you need more details.
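To cluster on x, y and z only, the label column has to be separated from the numeric columns first. A minimal sketch, assuming pandas is available; the sample rows are inlined via `io.StringIO` so it runs standalone, and the `n_init`/`random_state` values are illustrative choices rather than requirements:

```python
import io

import pandas as pd
from sklearn.cluster import KMeans

# Sample rows from the question; pass the real file path to read_csv instead
CSV_TEXT = """x,y,z, label
2,1,3, A
5,3,1, B
6,2,2, C
9,5,3, B
2,3,4, A
4,1,4, A
"""

# skipinitialspace=True drops the space after each comma, so the header
# becomes "label" (not " label") and the values become "A" (not " A")
df = pd.read_csv(io.StringIO(CSV_TEXT), skipinitialspace=True)

X = df[["x", "y", "z"]].to_numpy()    # numeric features only, shape (6, 3)
true_labels = df["label"].to_numpy()  # original labels, kept aside for comparison

# Illustrative settings: fixed seed so the cluster ids are repeatable
km = KMeans(n_clusters=3, n_init=10, random_state=0)
clusters = km.fit_predict(X)

# Cross-tabulate original labels against cluster ids to see how they line up
print(pd.crosstab(true_labels, clusters, rownames=["label"], colnames=["cluster"]))
```

`X` can then be used wherever the 3D plotting code below uses `data`.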

For a 2-dimensional dataset, I have used the following:

    from sklearn import cluster
    kmeans_labels = cluster.KMeans(n_clusters=5).fit_predict(data)

and plotted the 2-dimensional result like this:

    plt.scatter(standard_embedding[:, 0], standard_embedding[:, 1], c=kmeans_labels, s=0.1, cmap='Spectral')

Similarly, I would like to plot the 3-dimensional clustering with labels.


Solution

  • Could something like this be a good solution?

    import numpy as np
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    data = np.array([[2,1,3], [5,3,1], [6,2,2], [9,5,3], [2,3,4], [4,1,4]])
    
    cluster_count = 3
    km = KMeans(n_clusters=cluster_count, n_init=10, random_state=0)  # fixed seed for repeatable cluster ids
    clusters = km.fit_predict(data)
    
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=clusters, alpha=1)
    
    # k-means numbers clusters arbitrarily, so cluster i is not guaranteed to match the i-th CSV label
    labels = ["A", "B", "C"]
    for i, label in enumerate(labels):
        ax.text(km.cluster_centers_[i, 0], km.cluster_centers_[i, 1], km.cluster_centers_[i, 2], label)
    
    ax.set_title("3D K-Means Clustering")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    plt.show()
    

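The `labels` list above is matched to cluster indices purely by position, but k-means assigns cluster ids arbitrarily, so cluster 0 need not contain the points labelled "A". If the goal is to name each cluster after the CSV labels, one option is a majority vote over each cluster's members. A minimal sketch, assuming `clusters` comes from `fit_predict` and the original labels are held in a NumPy array; `map_clusters_to_labels` is a hypothetical helper, not part of scikit-learn:

```python
from collections import Counter

import numpy as np


def map_clusters_to_labels(clusters, true_labels):
    """Hypothetical helper: name each cluster id after the most common
    original label among its members (majority vote)."""
    mapping = {}
    for cluster_id in np.unique(clusters):
        members = true_labels[clusters == cluster_id]
        # most_common(1) -> [(label, count)]; on ties, the label seen first wins
        mapping[int(cluster_id)] = str(Counter(members).most_common(1)[0][0])
    return mapping


# Fixed example ids, standing in for the output of km.fit_predict(data)
true_labels = np.array(["A", "B", "C", "B", "A", "A"])
clusters = np.array([2, 0, 1, 0, 2, 2])

mapping = map_clusters_to_labels(clusters, true_labels)
print(mapping)  # {0: 'B', 1: 'C', 2: 'A'}
```

The resulting dict can replace the hard-coded list when placing text at the centroids, e.g. `ax.text(*km.cluster_centers_[i], mapping[i])`. Two clusters can end up with the same name when the clustering does not line up with the original labels.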

    EDIT

    If you want a legend instead, just do this:

    import numpy as np
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    data = np.array([[2,1,3], [5,3,1], [6,2,2], [9,5,3], [2,3,4], [4,1,4]])
    
    cluster_count = 3
    km = KMeans(n_clusters=cluster_count, n_init=10, random_state=0)  # fixed seed for repeatable cluster ids
    clusters = km.fit_predict(data)
    
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=clusters, alpha=1)
    
    handles = scatter.legend_elements()[0]
    # legend entries follow sorted cluster ids, which k-means assigns arbitrarily
    ax.legend(title="Clusters", handles=handles, labels=["A", "B", "C"])
    
    ax.set_title("3D K-Means Clustering")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    plt.show()
    

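An alternative that does not rely on `legend_elements` is to draw each cluster with its own `scatter` call and pass `label=`; `ax.legend()` then builds one entry per cluster automatically. A minimal sketch on the same sample data; the `Cluster N` names are placeholders (swap in mapped names if you have them), and `n_init`/`random_state` are illustrative:

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers "3d" on matplotlib < 3.2)
from sklearn.cluster import KMeans

data = np.array([[2, 1, 3], [5, 3, 1], [6, 2, 2], [9, 5, 3], [2, 3, 4], [4, 1, 4]])

# Illustrative settings: fixed seed so the cluster ids are repeatable
km = KMeans(n_clusters=3, n_init=10, random_state=0)
clusters = km.fit_predict(data)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# One scatter call per cluster: each call takes the next color and adds a legend entry
for cluster_id in np.unique(clusters):
    members = data[clusters == cluster_id]
    ax.scatter(members[:, 0], members[:, 1], members[:, 2],
               label=f"Cluster {cluster_id}")  # placeholder names

ax.legend(title="Clusters")
ax.set_title("3D K-Means Clustering")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()
```

This keeps the color and the legend entry tied to the same call, so they cannot drift out of order.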