Tags: python, scipy, cluster-analysis, hierarchical-clustering

Label Scipy Dendrogram by Average Cluster Value


I have a dendrogram calculated from some data points labeled 0-9.

How do I retrieve which data points (0-9) are in each node from the output of scipy.cluster.hierarchy.dendrogram? I want to label each node with the average (x, y) value of its points. I know I can retrieve the clusters from the clustering algorithm itself (scikit-learn's agglomerative clustering, for example), but I want to label every node in the dendrogram with the average value of the points it contains.

from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist, squareform
import numpy as np


np.random.seed(4711)  
X = np.random.uniform(-10,10,size=(10,2))
labels=[i for i in range(len(X))]
fig,ax=plt.subplots()
ax.scatter(X[:,0], X[:,1])

dist=squareform(pdist(X))

Y = linkage(squareform(dist), method='complete')
Z1 =dendrogram(Y, labels=labels, orientation='left')

For example, in the dendrogram above I would like to label the node that joins points 6 and 4 with the average of their x and y values, i.e. (X[6]+X[4])/2, and likewise for every other node.

Dendrogram output

scipy.cluster.hierarchy.dendrogram outputs a dictionary with the keys color_list, icoord, dcoord, ivl, leaves, and leaves_color_list. I feel like I should be able to use ivl combined with leaves and the coordinates, but I'm not sure how to interpret or use them.

(documentation for dendrogram)

Thank you.


Solution

  • You can use the leaf_label_func parameter of dendrogram to generate the label for each leaf.

    from matplotlib import pyplot as plt
    from scipy.cluster.hierarchy import dendrogram, linkage
    from scipy.spatial.distance import pdist
    import numpy as np

    np.random.seed(4711)
    # X will be interpreted as cluster centers (one row per data point)
    X = np.random.uniform(-10, 10, size=(10, 2))

    def llf(leaf_id):
        # Called once per leaf with that leaf's cluster index; returns its label
        return f'Cluster #{leaf_id} @ {X[leaf_id]}'

    # linkage accepts the condensed distance vector that pdist returns
    Y = linkage(pdist(X), method='complete')
    Z1 = dendrogram(Y, leaf_label_func=llf, get_leaves=True, orientation='left')
    plt.show()
    

    I also replaced squareform(squareform(pdist(X))) with pdist(X), since linkage accepts the condensed distance matrix that pdist returns directly.
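
Note that leaf_label_func only produces labels for the leaves. To annotate the internal nodes with the average of their points, as the question asks, one option is to rebuild each node's membership from the linkage matrix and place the text yourself. The following is a minimal sketch of that idea, not part of the answer above: the names members, node_mean and pos are illustrative, and it assumes dendrogram places the leaves at 5, 15, 25, ... along the leaf axis with each link centered between its two children. That matches the icoord values dendrogram returns, but it is a plotting convention rather than a documented guarantee.

```python
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
import numpy as np

np.random.seed(4711)
X = np.random.uniform(-10, 10, size=(10, 2))
n = len(X)
Z = linkage(pdist(X), method='complete')

# Row i of Z merges clusters Z[i, 0] and Z[i, 1] into a new cluster n + i,
# so the original points under every node can be accumulated bottom-up.
members = {i: [i] for i in range(n)}
for i, (a, b, height, count) in enumerate(Z):
    members[n + i] = members[int(a)] + members[int(b)]

# Average (x, y) of the original points under each node (leaves included).
node_mean = {k: X[idx].mean(axis=0) for k, idx in members.items()}

fig, ax = plt.subplots()
R = dendrogram(Z, labels=list(range(n)), orientation='left', ax=ax)

# Assumption: the p-th leaf in plotting order sits at 5 + 10 * p, and each
# merge is drawn midway between the positions of its two children.
pos = {leaf: 5.0 + 10.0 * p for p, leaf in enumerate(R['leaves'])}
for i, (a, b, height, count) in enumerate(Z):
    pos[n + i] = (pos[int(a)] + pos[int(b)]) / 2
    mx, my = node_mean[n + i]
    # With orientation='left' the merge distance is on the x axis and the
    # leaf position is on the y axis.
    ax.annotate(f'({mx:.1f}, {my:.1f})', xy=(height, pos[n + i]),
                xytext=(3, 3), textcoords='offset points', fontsize=7)
```

Call plt.show() or fig.savefig(...) afterwards to view the result. For orientation='top' or 'bottom', swap the two coordinates passed as xy.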