python pandas matplotlib scikit-learn cluster-analysis

Plot cluster matrix


I want to plot a cluster matrix of scikit-learn K-means results for the following pandas DataFrame:

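import pandas as pd
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()  # toy dataset
# Pass the feature names directly; wrapping them in a list would create a MultiIndex
data = pd.DataFrame(cancer.data, columns=cancer.feature_names)
df = data.iloc[:, 4:8].copy()  # select a subset; copy so the columns can be renamed safely
df.columns = ['smoothness', 'compactness', 'concavity', 'concave points']
df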

+----+--------------+---------------+-------------+------------------+
|    |   smoothness |   compactness |   concavity |   concave points |
|----+--------------+---------------+-------------+------------------|
|  0 |      0.1184  |       0.2776  |      0.3001 |          0.1471  |
|  1 |      0.08474 |       0.07864 |      0.0869 |          0.07017 |
|  2 |      0.1096  |       0.1599  |      0.1974 |          0.1279  |
|  3 |      0.1425  |       0.2839  |      0.2414 |          0.1052  |
|  4 |      0.1003  |       0.1328  |      0.198  |          0.1043  |
+----+--------------+---------------+-------------+------------------+

Solution

  • IIUC, you could simplify this by using seaborn.pairplot and passing KMeans.labels_ as the hue argument. For example:

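    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.cluster import KMeans

    def kmeans_scatterplot(df, n_clusters):
        # Fit K-means on the selected features
        km = KMeans(init='k-means++', n_clusters=n_clusters)
        km_clustering = km.fit(df)
        # Add the cluster labels as a column and color every pairwise plot by it
        sns.pairplot(df.assign(hue=km_clustering.labels_), hue='hue')

    kmeans_scatterplot(df, 2)
    plt.show()  # only needed when running as a script rather than in a notebook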
    

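If seaborn is not an option, a similar matrix can probably be produced with pandas and matplotlib alone. The following is a minimal sketch, not the answer's own method: it swaps in pandas.plotting.scatter_matrix and passes the cluster labels through the c keyword, which is forwarded to the underlying scatter calls. The function name kmeans_scatter_matrix, the n_init=10 and random_state=0 settings, and the figure size are assumptions chosen for illustration.

```python
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_breast_cancer

# Rebuild the same four-column frame used in the question
cancer = load_breast_cancer()
data = pd.DataFrame(cancer.data, columns=cancer.feature_names)
df = data.iloc[:, 4:8].copy()
df.columns = ['smoothness', 'compactness', 'concavity', 'concave points']


def kmeans_scatter_matrix(df, n_clusters, random_state=0):
    """Hypothetical helper: cluster df with K-means and draw a scatter matrix.

    random_state and n_init are fixed here only so repeated runs give the
    same labels; they are illustrative assumptions.
    """
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    labels = km.fit_predict(df)  # one integer cluster label per row
    # Extra keyword arguments such as c are passed on to the scatter plots,
    # so each off-diagonal point is colored by its cluster label
    axes = pd.plotting.scatter_matrix(
        df, c=labels, figsize=(8, 8), diagonal='hist', alpha=0.6
    )
    return labels, axes


labels, axes = kmeans_scatter_matrix(df, 2)
```

In a script, call matplotlib.pyplot.show() afterwards to display the figure. Unlike the pairplot version, the diagonal histograms here are not split by cluster, and no legend is drawn automatically.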