python, scikit-learn, hierarchical-clustering

Extract path from root to leaf in sklearn's agglomerative clustering


Given a specific leaf node of the hierarchy built by sklearn.cluster.AgglomerativeClustering, I am trying to identify the path from the root node (containing all data points) to that leaf, together with the list of data points belonging to each intermediate step (internal node of the tree), as in the example below.


In this example there are five data points and I focus on point 3: I want to extract the points contained in each cluster along the path that starts at the root and ends at leaf 3, so the desired result would be [[1,2,3,4,5],[1,3,4,5],[3,4],[3]]. How can I achieve this with sklearn (or, if that is not possible, with a different library)?


Solution

  • The code below first finds all ancestors of the focus point (with the find_ancestor function below), then collects every data point under each ancestor (with find_descendent).

    First, load the data and fit the model:

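    from sklearn.cluster import AgglomerativeClustering
    from sklearn.datasets import load_iris

    iris = load_iris()
    N = 10  # number of samples, i.e. number of leaves in the tree
    x = iris.data[:N]
    model = AgglomerativeClustering(compute_full_tree=True).fit(x)  # build the tree up to the root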
    

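    Here is the main code (it uses find_ancestor and find_descendent, defined below, so run those definitions first):

    ans = []
    # find_ancestor returns the path leaf -> root; reverse it to go root -> leaf
    for a in find_ancestor(3)[::-1]:
        ans.append(find_descendent(a))
    print(ans)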
    

    In my case this outputs:

    [[1, 9, 8, 6, 2, 3, 5, 7, 0, 4],
     [1, 9, 8, 6, 2, 3],
     [8, 6, 2, 3],
     [6, 2, 3],
     [2, 3],
     [3]]
    

    To understand the code of find_ancestor, remember that the two children of the non-leaf node with index N + i are stored in model.children_[i], while indices below N are leaves (the original samples).
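To make that indexing concrete, here is a small sketch on a hand-written, hypothetical children_ array (not produced by sklearn) that encodes the tree from the question, with points 1..5 mapped to ids 0..4. The helper name internal_nodes is illustrative only; it simply pairs each row i with its node id n_samples + i.

```python
# Hypothetical, hand-written children_ array for 5 samples (ids 0..4).
# It encodes the question's tree, with points 1..5 mapped to ids 0..4.
children = [[2, 3], [0, 4], [5, 6], [1, 7]]
n_samples = len(children) + 1  # a full tree has n_samples - 1 merges


def internal_nodes(children, n_samples):
    """Yield (node_id, left_child, right_child) for each merge, in merge order."""
    for i, (left, right) in enumerate(children):
        yield n_samples + i, left, right


for node, left, right in internal_nodes(children, n_samples):
    print(f"node {node} = merge of {left} and {right}")
# node 5 = merge of 2 and 3
# node 6 = merge of 0 and 4
# node 7 = merge of 5 and 6
# node 8 = merge of 1 and 7
```

The last merge, node 2 * n_samples - 2 (here 8), is the root, which is why it never appears as a child in any row.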

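    def find_ancestor(target):
        """Return the path from `target` up to the root (leaf first, root last)."""
        for ind, pair in enumerate(model.children_):
            if target in pair:
                # internal node N + ind is the parent of target
                return [target] + find_ancestor(N + ind)
        # target appears in no pair, so it is the root
        return [target]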
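find_ancestor rescans model.children_ once per level, so its cost grows with depth times N. A possible alternative, sketched below under the assumption that children has the same layout as model.children_, builds a parent lookup once and then walks upward. The names build_parent and ancestors are hypothetical helpers, and the toy children array is hand-written to match the question's tree.

```python
def build_parent(children, n_samples):
    """Map every node id to the id of the internal node that merged it."""
    parent = {}
    for i, (left, right) in enumerate(children):
        parent[int(left)] = n_samples + i
        parent[int(right)] = n_samples + i
    return parent


def ancestors(target, parent):
    """Return the path from `target` up to the root (leaf first, root last)."""
    path = [target]
    while path[-1] in parent:  # only the root has no parent
        path.append(parent[path[-1]])
    return path


# Hand-written toy tree from the question (points 1..5 -> ids 0..4).
children = [[2, 3], [0, 4], [5, 6], [1, 7]]
parent = build_parent(children, n_samples=len(children) + 1)
print(ancestors(2, parent))  # [2, 5, 7, 8]
```

With the fitted model, build_parent(model.children_, N) would be called once, after which ancestors(3, parent) should give the same leaf-to-root path as find_ancestor(3).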
    

    The recursive find_descendent memoizes its results in the dictionary mem, so the leaf lists computed for one ancestor are reused for the nested ancestors instead of being recomputed.

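    mem = {}  # cache: node id -> list of sample ids under that node
    def find_descendent(node):
        """Return the sample ids (leaves) under `node`."""
        if node in mem:
            return mem[node]
        if node < N:  # leaf: an original sample
            return [node]
        left, right = model.children_[node - N]
        # int() keeps the output as plain Python ints instead of NumPy integers
        mem[node] = find_descendent(int(left)) + find_descendent(int(right))
        return mem[node]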
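On large or very unbalanced trees the recursion in find_descendent may get close to N levels deep and could exceed Python's default recursion limit (1000). A minimal iterative sketch with an explicit stack is shown below; find_leaves is a hypothetical name, and the hand-written toy tree and root-to-leaf path [8, 7, 5, 2] (what find_ancestor would return for leaf 2, reversed) reproduce the question's example.

```python
def find_leaves(node, children, n_samples):
    """Return the sorted sample ids under `node`, using a stack instead of recursion."""
    leaves, stack = [], [node]
    while stack:
        current = stack.pop()
        if current < n_samples:  # leaf: an original sample
            leaves.append(current)
        else:  # internal node: visit both children later
            left, right = children[current - n_samples]
            stack.extend((int(left), int(right)))
    return sorted(leaves)


# Hand-written toy tree from the question (points 1..5 -> ids 0..4).
children = [[2, 3], [0, 4], [5, 6], [1, 7]]
path = [8, 7, 5, 2]  # root-to-leaf path for leaf 2 in this toy tree
clusters = [[i + 1 for i in find_leaves(n, children, 5)] for n in path]  # back to 1-based
print(clusters)  # [[1, 2, 3, 4, 5], [1, 3, 4, 5], [3, 4], [3]]
```

Sorting each list is a presentation choice; drop sorted() if the merge order of the leaves matters.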