Tags: plot, scipy, hierarchical-clustering, linkage, dendrogram

Relation between dendrogram plot coordinates and ClusterNodes in scipy


I'm looking for a way to get the coordinates of a cluster's point in the dendrogram plot from its ClusterNode, as returned by to_tree.

I'm using scipy to build a dendrogram from data like this:

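from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, dendrogram, to_tree

X = data  # an (n_samples, n_features) array
Y = pdist(X)
Z = linkage(Y)
dend = dendrogram(Z)
rootnode, nodesList = to_tree(Z, rd=True)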

What I would like is a function get_coords(someClusterNode) that returns the tuple (x, y) giving the position of that node in the plot.

Thanks to this answer, I managed to figure out how to get a position from the dictionary returned by dendrogram, for example:

import matplotlib.pyplot as plt

# take the last link drawn and mark the apex of its U shape
i, d = list(zip(dend['icoord'], dend['dcoord']))[-1]
x = 0.5 * sum(i[1:3])
y = d[1]
plt.plot(x, y, 'ro')
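The same arithmetic works for every link, not just the last one: each entry of icoord/dcoord describes one U-shaped link whose apex sits midway between its two inner x-values, at height d[1]. Here is a minimal sketch that collects all apexes (the helper name link_apexes and the random data are only for illustration). Since dendrogram appends a link only after drawing both of its children, the last apex is the root, whose height is the final merge distance Z[-1, 2].

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist


def link_apexes(dend):
    """Return the (x, y) apex of every U-shaped link, in the order dendrogram drew them."""
    apexes = []
    for i, d in zip(dend['icoord'], dend['dcoord']):
        # i = [x_left, x_left, x_right, x_right], d = [y_left, h, h, y_right]
        apexes.append((0.5 * (i[1] + i[2]), d[1]))
    return apexes


rng = np.random.default_rng(0)
X = rng.random((10, 3))
Z = linkage(pdist(X))
dend = dendrogram(Z, no_plot=True)  # no_plot=True: compute the layout without drawing

apexes = link_apexes(dend)
print(len(apexes))                # 9: one link per merge for 10 points
print(apexes[-1][1] == Z[-1, 2])  # True: the last link is the root merge
```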

But I can't figure out the relation between the ordering of nodesList and the ordering of icoord/dcoord, which I would need to map one to the other.

Do you have any idea where I should look?

Thanks for your help!


Solution

  • Each dendrogram maps to only one tree of ClusterNodes, but a given tree of ClusterNodes can be drawn as many different dendrograms (e.g. with different leaf orderings or orientations). Hence your mapping from node ID to (x,y) positions should probably be just another field of your dendrogram data structure rather than a function of a ClusterNode alone. Instead of defining a function get_coords, I therefore add a dictionary to dend that maps node IDs to (x,y) coordinates. You can access the positions with

    x,y = dend['node_id_to_coord'][node_id] # node_id is an integer as returned by ClusterNode.id
    

    Code:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.cluster.hierarchy import linkage, dendrogram, to_tree
    from scipy.spatial.distance import pdist
    
    # create some random data
    X = np.random.rand(10, 3)
    
    # get dendrogram
    Z = linkage(pdist(X), method="ward")
    dend = dendrogram(Z)
    
    # ----------------------------------------
    # get leaf coordinates, which are at y == 0
    
    def flatten(l):
        return [item for sublist in l for item in sublist]
    xs = flatten(dend['icoord'])
    ys = flatten(dend['dcoord'])
    leaf_coords = [(x, y) for x, y in zip(xs, ys) if y == 0]
    
    # dend['leaves'] lists the leaf ids from left to right,
    # so sort the leaf coordinates by x before pairing them with the ids
    order = np.argsort([x for x, y in leaf_coords])
    id_to_coord = dict(zip(dend['leaves'], [leaf_coords[idx] for idx in order])) # <- main data structure
    
    # ----------------------------------------
    # get coordinates of other nodes
    
    # this looks like it should work, but it doesn't: the rows of dend['icoord'] and
    # dend['dcoord'] are not in linkage-matrix order, because dendrogram emits each
    # link only after drawing both of its subtrees (a post-order walk from the root)
    
    # # traverse tree from leaves upwards and populate mapping ID -> (x,y);
    # # use linkage matrix to traverse the tree optimally
    # # (assumes each row in the linkage matrix corresponds to a row in dend['icoord'] and dend['dcoord'])
    # root_node, node_list = to_tree(Z, rd=True)
    # for ii, (X, Y) in enumerate(zip(dend['icoord'], dend['dcoord'])):
    #     x = (X[1] + X[2]) / 2
    #     y = Y[1] # or Y[2]
    #     node_id = ii + len(dend['leaves'])
    #     id_to_coord[node_id] = (x, y)
    
    # so we need to do it the hard way:
    
    # map endpoint of each link to coordinates of parent node
    children_to_parent_coords = dict()
    for i, d in zip(dend['icoord'], dend['dcoord']):
        x = (i[1] + i[2]) / 2
        y = d[1] # or d[2]
        parent_coord = (x, y)
        left_coord = (i[0], d[0])
        right_coord = (i[-1], d[-1])
        children_to_parent_coords[(left_coord, right_coord)] = parent_coord
    
    # traverse tree from leaves upwards and populate mapping ID -> (x,y)
    root_node, node_list = to_tree(Z, rd=True)
    ids_left = range(len(dend['leaves']), len(node_list))
    
    while len(ids_left) > 0:
    
        for node_id in ids_left:
            node = node_list[node_id]
            if (node.left.id in id_to_coord) and (node.right.id in id_to_coord):
                left_coord = id_to_coord[node.left.id]
                right_coord = id_to_coord[node.right.id]
                id_to_coord[node_id] = children_to_parent_coords[(left_coord, right_coord)]
    
        ids_left = [node_id for node_id in range(len(node_list)) if node_id not in id_to_coord]
    
    # plot result on top of dendrogram
    ax = plt.gca()
    for node_id, (x, y) in id_to_coord.items():
        if not node_list[node_id].is_leaf():
            ax.plot(x, y, 'ro')
            ax.annotate(str(node_id), (x, y), xytext=(0, -8),
                        textcoords='offset points',
                        va='top', ha='center')
    
    dend['node_id_to_coord'] = id_to_coord
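
The commented-out attempt fails because it assumes the rows of dend['icoord'] and dend['dcoord'] follow the rows of the linkage matrix. As far as I can tell from scipy's implementation, dendrogram instead walks the tree from the root and appends each U-shaped link only after both of its subtrees have been drawn, left one first, i.e. in post-order. Pairing the links with a post-order walk over the to_tree result should therefore give the mapping without matching any coordinates. This is a sketch that assumes default dendrogram arguments (count_sort, distance_sort and truncate_mode change the drawing order or drop links); the helper names postorder_ids and compute_node_coords are my own:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram, to_tree
from scipy.spatial.distance import pdist


def postorder_ids(root):
    """Ids of the internal (non-leaf) nodes under `root`, in post-order: left, right, node."""
    ids = []
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if node.is_leaf():
            continue  # leaves have no link of their own
        if expanded:
            ids.append(node.id)  # both subtrees have already been emitted
        else:
            stack.append((node, True))
            stack.append((node.right, False))  # popped after the left subtree
            stack.append((node.left, False))   # popped first
    return ids


def compute_node_coords(Z, dend):
    """Map every ClusterNode id (leaves included) to its (x, y) position in `dend`.

    Assumes dend = dendrogram(Z) was called with the default ordering options.
    """
    root, nodes = to_tree(Z, rd=True)
    coords = {}
    for node_id, i, d in zip(postorder_ids(root), dend['icoord'], dend['dcoord']):
        node = nodes[node_id]
        # i = [x_left, x_left, x_right, x_right], d = [y_left, h, h, y_right]
        coords[node_id] = (0.5 * (i[1] + i[2]), d[1])
        coords[node.left.id] = (i[0], d[0])
        coords[node.right.id] = (i[3], d[3])
    return coords


rng = np.random.default_rng(1)
X = rng.random((10, 3))
Z = linkage(pdist(X), method="ward")
dend = dendrogram(Z, no_plot=True)
dend['node_id_to_coord'] = compute_node_coords(Z, dend)
```

Each link also records where its two children sit ((i[0], d[0]) for the left one, (i[3], d[3]) for the right one), so the leaves get their coordinates from the same loop and no y == 0 filter is needed. That filter would probably pick up extra points if some merges happen at distance 0, e.g. for duplicate rows in X.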
    

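If only the positions are needed, icoord/dcoord can be skipped altogether. With the default, untruncated layout, dendrogram puts the k-th leaf of dend['leaves'] at x = 10*k + 5 (the same values it uses for the x tick positions) and y = 0, and draws every merge midway between its two children at height Z[row, 2]. Because each row of a linkage matrix only refers to leaves or to clusters created by earlier rows, one pass over Z fills in every internal node. A minimal sketch under those assumptions (coords_from_linkage is a hypothetical helper name):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import pdist


def coords_from_linkage(Z, leaves):
    """Map node ids to (x, y) from the linkage matrix Z and the leaf order `leaves`.

    Assumes the default, untruncated dendrogram layout, where `leaves` is
    dend['leaves'] and leaves are spaced 10 units apart starting at x = 5.
    """
    n = Z.shape[0] + 1
    coords = {int(leaf): (10.0 * k + 5.0, 0.0) for k, leaf in enumerate(leaves)}
    for row, (a, b, dist, _count) in enumerate(Z):
        xa, _ = coords[int(a)]
        xb, _ = coords[int(b)]
        # this row creates cluster n + row, drawn midway between a and b at height dist
        coords[n + row] = (0.5 * (xa + xb), float(dist))
    return coords


rng = np.random.default_rng(2)
X = rng.random((8, 2))
Z = linkage(pdist(X), method="average")
dend = dendrogram(Z, no_plot=True)
coords = coords_from_linkage(Z, dend['leaves'])
```

The result should agree with the apexes read from dend['icoord'] and dend['dcoord'], but it silently relies on the 10-unit leaf spacing, so reading the coordinates from the dendrogram output is the safer choice if you pass layout options such as truncate_mode.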