python scikit-learn scipy hierarchical-clustering

SciPy Dendrogram Plotting


I am experimenting with hierarchical document clustering, and my workflow is roughly this:

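import pandas
from scipy.cluster.hierarchy import ward, fcluster
from sklearn.metrics.pairwise import cosine_similarity

df = pandas.read_csv(file, delimiter='\t', index_col=0)  # document-term matrix (very sparse)
dist_matrix = cosine_similarity(df)

linkage_matrix = ward(dist_matrix)
labels = fcluster(linkage_matrix, 5, criterion='maxclust')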

I expect to get 5 clusters, but when I plot the dendrogram:

import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram

fig, ax = plt.subplots(figsize=(15, 20))  # set size
dendrogram(linkage_matrix, orientation="right", ax=ax)  # draw on the axes created above
ax.tick_params(
    axis='x',  # changes apply to the x-axis
    which='both',  # both major and minor ticks are affected
    bottom=False,  # ticks along the bottom edge are off
    top=False,  # ticks along the top edge are off
    labelbottom=False)  # labels along the bottom edge are off

plt.tight_layout()  # show plot with tight layout

plt.savefig('ward_clusters.png', dpi=200)  # save figure as ward_clusters

I get the following graph

[Image: the resulting dendrogram]

Based on the colors I can see 3 clusters, not 5! Am I misunderstanding the meaning of the dendrogram?


Solution

    • First of all, if you just want 5 clusters, use labels (the result of the fcluster line, which you computed but never used).

    In labels, each point of your dataset is represented by a number: the id of the cluster it belongs to.
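
As a rough illustration of what labels contains, here is a minimal sketch. The random matrix X is a hypothetical stand-in for the question's data, and Ward linkage is applied directly to the observation vectors; neither comes from the original post.

```python
import numpy as np
from scipy.cluster.hierarchy import ward, fcluster

# Hypothetical stand-in data: 60 observations with 4 features each
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 4))

# A 2-D array is treated by ward() as a set of observation vectors
linkage_matrix = ward(X)

# Ask for at most 5 flat clusters; labels[i] is the cluster id of row i
labels = fcluster(linkage_matrix, 5, criterion="maxclust")

# Count how many observations fall into each cluster id
cluster_ids, sizes = np.unique(labels, return_counts=True)
print(dict(zip(cluster_ids.tolist(), sizes.tolist())))
```

The cluster ids start at 1, and labels has one entry per row of the input, so it can be attached to the original DataFrame as a new column.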

    • If you want to use a dendrogram and show 5 different clusters on it, you have to "cut" the dendrogram.

    Draw a vertical line at about x = 5 and treat each sub-dendrogram to the left of that line as an independent cluster.

    [Image: the dendrogram with a vertical cut line]

    This artificially cuts your dendrogram into 5 parts (that is, 5 clusters).
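
Instead of reading the cut position off the plot, it can also be derived from the linkage matrix. The sketch below is one possible way, assuming hypothetical random data and a monotone linkage such as Ward, so that the merge heights in column 2 are sorted in increasing order. The names X, Z, k and cut are illustrative only.

```python
import numpy as np
from scipy.cluster.hierarchy import ward, fcluster, dendrogram

# Hypothetical stand-in data: 60 observations with 4 features each
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 4))
Z = ward(X)

k = 5  # desired number of clusters (must be at least 2 here)

# Z[:, 2] holds the merge heights. Undoing the last k - 1 merges leaves
# k subtrees, so any cut between Z[-k, 2] and Z[-(k - 1), 2] works.
cut = (Z[-k, 2] + Z[-(k - 1), 2]) / 2

# Flat clusters obtained by cutting the tree at that height
labels = fcluster(Z, cut, criterion="distance")
n_clusters = len(np.unique(labels))

# The same height can serve as color_threshold; no_plot=True only
# computes the layout and colors without drawing anything
tree = dendrogram(Z, color_threshold=cut, no_plot=True)
print(n_clusters, sorted(set(tree["color_list"])))
```

A cluster that consists of a single leaf has no links of its own, so the dendrogram may show fewer distinct colors than there are clusters.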

    To color the clusters differently, adapt the following code (since you didn't provide your dataset, I used the iris dataset to show one possible solution):

    from scipy.cluster.hierarchy import ward, dendrogram
    from sklearn.datasets import load_iris
    from sklearn.metrics.pairwise import cosine_similarity
    import pandas as pd
    import matplotlib.pyplot as plt

    iris = load_iris()

    data = iris['data']
    df = pd.DataFrame(data, columns=iris['feature_names'])

    # Build something equivalent to the question's workflow
    dist_matrix = cosine_similarity(df)
    linkage_matrix = ward(dist_matrix)

    fig, ax = plt.subplots(figsize=(20, 10))

    # color_threshold is the position of the cut line: 0.7 suits this iris
    # example; for the plot in the question it would be about 5
    dendrogram(linkage_matrix, color_threshold=0.7, ax=ax)

    # Hide the ticks and labels along the x-axis
    ax.tick_params(
        axis='x',
        which='both',
        bottom=False,
        top=False,
        labelbottom=False)

    plt.show()
    

    [Image: the resulting colored dendrogram]
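
One caveat about the workflow used in both snippets: cosine_similarity returns similarities (1 means identical direction), not distances, and SciPy's linkage functions interpret a 2-D array as a set of observation vectors rather than as a precomputed distance matrix. A minimal sketch of one possible alternative follows. It assumes hypothetical random data, and the choice of average linkage is also an assumption, made because SciPy documents Ward's method as correctly defined only for Euclidean distances.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist

# Hypothetical stand-in for a document-term matrix (non-negative values)
rng = np.random.default_rng(0)
X = rng.random(size=(40, 6))

# Condensed vector of pairwise cosine distances (1 - cosine similarity)
cos_dist = pdist(X, metric="cosine")

# A 1-D condensed vector is interpreted as precomputed distances
Z = linkage(cos_dist, method="average")

# Flat clustering into at most 5 clusters, as in the question
labels = fcluster(Z, 5, criterion="maxclust")
print(np.unique(labels))
```

The resulting Z can be passed to dendrogram exactly as before; only the scale of the merge heights, and therefore a suitable color_threshold, changes.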