Tags: python, matplotlib, d3.js, scikit-learn, hierarchical-clustering

How to Parse Data from SciKitLearn Agglomerative Clustering


I feel like very similar versions of this question have been asked, and I have learned a lot while trying to solve this problem, but there is some (probably very basic) concept that I am missing.

Here are three very similar questions with very good answers:

Extract path from root to leaf in sklearn's agglomerative clustering

Yield all root-to-leaf branches of a Binary Tree

How do you visualize a ward tree from sklearn.cluster.ward_tree?

And there is some great stuff in Mike Bostock's D3 Git repo.

Now the specifics of my situation:

I have done some analysis in Python using sklearn Agglomerative Clustering. I am generating the dendrograms I would like to see in Matplotlib:

T=7: [image: T=7 dendrogram]

T=2: [image: T=2 dendrogram]

Now I would like to add those dendrograms and some other functionality to a web site. I have built a site using Django.

I have used some D3 Javascript functionality already to implement a dynamic and interactive Tree Diagram like this one:

https://bl.ocks.org/d3noob/8375092

And I have made it so it loads the information for each branch from a JSON file, so it is both dynamic and interactive. [image: interactive tree diagram]

Now I want to mimic some of that functionality with the info from the Agglomerative Clustering.

I want to:

  1. Make a dendrogram similar to the one from Matplotlib and make it interactive: a slider should let the user change the T value, and the dendrogram should redraw.

    A. I am open to any suggestions. I can brute-force a solution by simply recalculating the dendrogram in Python (as a module inside the Django app), saving an image, and loading that image on the JavaScript side in the template. I think there is probably a more elegant solution with D3, but I am running out of time to do research.

  2. Create an interactive Tree Diagram using the info from the clustering. I would like to see the dendrogram far more interactive as a tree. It seems like I should be able to use either the agglo_model.children_ or the linkage_matrix.

agglo_model.children_: 
[[ 35  36][ 13  18][ 19  20]...[ 22  69][ 33  34][ 14  32]]

or the linkage_matrix:
[[ 35.          36.           0.           2.        ]
 [ 22.          69.           1.73205081   4.        ]
...
 [ 50.          57.           4.47213595   2.        ]
 [  9.          41.           4.69041576   2.        ]
...
 [116.         126.          12.62713128  36.        ]
 [128.         129.          17.97652791  66.        ]]
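For readers comparing the two structures: as far as I understand scikit-learn's convention, row i of `children_` lists the two nodes merged at step i, ids below `n_samples` are original samples, and the merged node gets id `n_samples + i`. A minimal sketch, under that assumption, of turning `children_` plus per-merge distances into SciPy-style linkage rows follows. The helper name `linkage_rows_from_children` and the toy inputs are made up for illustration; with a real model the distances would come from `agglo_model.distances_`, which is only populated when the model is fitted with `compute_distances=True` or a `distance_threshold`.

```python
def linkage_rows_from_children(children, distances):
    """Build [child_a, child_b, distance, leaf_count] rows from merge pairs.

    `children` is a sequence of (a, b) pairs in merge order, and `distances`
    holds one merge distance per pair. Both are assumed inputs here.
    """
    n_samples = len(children) + 1  # a full tree over n samples has n - 1 merges
    counts = []                    # leaf count of each merged node, in merge order
    rows = []
    for (a, b), dist in zip(children, distances):
        count = 0
        for child in (int(a), int(b)):
            if child < n_samples:
                count += 1                          # original sample
            else:
                count += counts[child - n_samples]  # previously merged node
        counts.append(count)
        rows.append([int(a), int(b), float(dist), count])
    return rows


# Toy values, not taken from the question's data.
toy_children = [[0, 1], [2, 3], [4, 5]]
toy_distances = [1.0, 2.0, 5.0]
toy_rows = linkage_rows_from_children(toy_children, toy_distances)
for row in toy_rows:
    print(row)
```

The function only relies on `len`, iteration and integer conversion, so it should accept NumPy arrays as well as plain lists.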

The key piece I am missing is how to go from scikit to the following tree format for d3.js

    var treeData = [
  {
    "name": "Top Level",
    "parent": "null",
    "children": [
      {
        "name": "Level 2: A",
        "parent": "Top Level",
        "children": [
          {
            "name": "Son of A",
            "parent": "Level 2: A"
          },
          {
            "name": "Daughter of A",
            "parent": "Level 2: A"
          }
        ]
      },
      {
        "name": "Level 2: B",
        "parent": "Top Level"
      }
    ]
  }
];
  3. Show a clustering diagram on the web page. Basically, I'd like to mimic this page: Agglomerative Clustering and Matplotlib Diagrams and Dendrograms with interactive JavaScript.

Again - any suggestions appreciated.


Solution

  • There's a lot to this question but I think it boils down to how can one get the data from a linkage matrix into a d3 tree diagram?

    Given this python code:

    from scipy.cluster.hierarchy import dendrogram, linkage
    import numpy as np
    from matplotlib import pyplot as plt
       
    X = np.array([[5,3],
        [10,15],
        [15,12],
        [24,10],
        [30,30],
        [85,70],
        [71,80],
        [60,78],
        [70,55],
        [80,91],])
    
    linked = linkage(X, "single")
    
    labelList = range(0, 10)
    
    plt.figure(figsize=(10, 7))
    dendrogram(linked,
                orientation='top',
                labels=labelList,
                distance_sort='descending',
                show_leaf_counts=False)
    plt.show()
    

    Which generates this linkage matrix:

    [[ 1.          2.          5.83095189  2.        ]
     [ 3.         10.          9.21954446  3.        ]
     [ 6.          7.         11.18033989  2.        ]
     [ 0.         11.         13.          4.        ]
     [ 9.         12.         14.2126704   3.        ]
     [ 5.         14.         17.20465053  4.        ]
     [ 4.         13.         20.88061302  5.        ]
     [ 8.         15.         21.21320344  5.        ]
     [16.         17.         47.16990566 10.        ]]
    

    And this dendrogram:

    [image: dendrogram]
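As a reading aid, here is a minimal sketch of how each row of the linkage matrix can be interpreted, assuming SciPy's convention that row i of an (n - 1)-row matrix creates cluster id n + i from the two ids in its first two columns. The helper name `describe_merges` is made up for illustration, and the rows are copied by hand from the matrix above so the snippet runs without NumPy.

```python
# Rows copied from the linkage matrix above:
# [child_a, child_b, merge_distance, number_of_leaves_in_new_cluster]
linked_rows = [
    [1, 2, 5.83095189, 2],
    [3, 10, 9.21954446, 3],
    [6, 7, 11.18033989, 2],
    [0, 11, 13.0, 4],
    [9, 12, 14.2126704, 3],
    [5, 14, 17.20465053, 4],
    [4, 13, 20.88061302, 5],
    [8, 15, 21.21320344, 5],
    [16, 17, 47.16990566, 10],
]


def describe_merges(rows):
    """Return one (new_id, child_a, child_b, distance, leaf_count) tuple per row."""
    n_samples = len(rows) + 1  # a full linkage over n points has n - 1 rows
    merges = []
    for i, (a, b, dist, count) in enumerate(rows):
        new_id = n_samples + i  # row i creates cluster number n_samples + i
        merges.append((new_id, int(a), int(b), float(dist), int(count)))
    return merges


for new_id, a, b, dist, count in describe_merges(linked_rows):
    print(f"cluster {new_id} = {a} + {b} at distance {dist:.2f} ({count} leaves)")
```

Ids 0 to 9 are the original points, so any id of 10 or more in the first two columns refers back to an earlier row. That back-reference is what encodes the hierarchy.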

    How, then, can we convert the linkage matrix into a data structure compatible with the d3 tree? It took me a bit to grok the structure of the matrix and how it conveys hierarchy. This post and this post explain it really well. So, let's do a little data manipulation in Python and then JSON-ify the result. If you are using Django, you can return this JSON via an API call.

    import json
    
    def create_tree(linked):
        # Inner function that recursively walks the linkage matrix.
        def recurTree(tree):
            k = tree['name']
            # Leaf node: original observations never appear as keys in `inter`.
            if k not in inter:
                return
            for n in inter[k]:
                # Build the child node.
                node = {
                    "name": n,
                    "parent": k,
                    "children": []
                }
                # Attach it to its parent.
                tree['children'].append(node)
                # Recurse into the child.
                recurTree(node)

        num_rows, _ = linked.shape
        inter = {}
        i = 0
        # Convert the linkage matrix to a dict: cluster id -> [left child, right child].
        # Row number i (counting from 1) creates cluster id i + num_rows.
        for row in linked:
            i += 1
            inter[float(i + num_rows)] = [float(row[0]), float(row[1])]

        # Root of the tree: the cluster created by the last row.
        tree = {
            "name": float(i + num_rows),
            "parent": None,
            "children": []
        }
        # Start the recursion at the root.
        recurTree(tree)
        return tree
    
    print(json.dumps(create_tree(linked), indent = 2))
    

    This produces JSON like so:

    {
      "name": 18.0,
      "parent": null,
      "children": [
        {
          "name": 16.0,
          "parent": 18.0,
          "children": [
            {
              "name": 4.0,
              "parent": 16.0,
              "children": []
            },
            {
              "name": 13.0,
              "parent": 16.0,
              "children": [
                {
                  "name": 0.0,
                  "parent": 13.0,
                  "children": []
                },
                {
                  "name": 11.0,
                  "parent": 13.0,
                  "children": [
                    {
                      "name": 3.0,
                      "parent": 11.0,
                      "children": []
                    },
                    {
                      "name": 10.0,
                      "parent": 11.0,
                      "children": [
                        {
                          "name": 1.0,
                          "parent": 10.0,
                          "children": []
                        },
                        {
                          "name": 2.0,
                          "parent": 10.0,
                          "children": []
                        }
                      ]
                    }
                  ]
                }
              ]
            }
          ]
        },
        {
          "name": 17.0,
          "parent": 18.0,
          "children": [
            {
              "name": 8.0,
              "parent": 17.0,
              "children": []
            },
            {
              "name": 15.0,
              "parent": 17.0,
              "children": [
                {
                  "name": 5.0,
                  "parent": 15.0,
                  "children": []
                },
                {
                  "name": 14.0,
                  "parent": 15.0,
                  "children": [
                    {
                      "name": 9.0,
                      "parent": 14.0,
                      "children": []
                    },
                    {
                      "name": 12.0,
                      "parent": 14.0,
                      "children": [
                        {
                          "name": 6.0,
                          "parent": 12.0,
                          "children": []
                        },
                        {
                          "name": 7.0,
                          "parent": 12.0,
                          "children": []
                        }
                      ]
                    }
                  ]
                }
              ]
            }
          ]
        }
      ]
    }
    

    If we then dump that into the d3 tree example (I made it vertical), you end up with this:

    [image: vertical d3 tree rendered from the JSON above]
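The nodes in that JSON carry only ids. If the d3 side should also position nodes by merge height, as a dendrogram does, one option is to copy the third and fourth linkage columns into each node. The following is a minimal sketch under that assumption. The function name `build_tree_with_heights` and the `distance` and `count` keys are illustrative choices, not anything d3 requires, and the rows are toy values.

```python
import json


def build_tree_with_heights(rows):
    """Nested dict from SciPy-style linkage rows, keeping distance and leaf count."""
    n_samples = len(rows) + 1

    def make_node(node_id, parent_id):
        # Leaves get distance 0.0 and count 1 by default.
        node = {"name": node_id, "parent": parent_id,
                "distance": 0.0, "count": 1, "children": []}
        if node_id >= n_samples:
            # Internal node: look up the row that created it.
            a, b, dist, count = rows[node_id - n_samples]
            node["distance"] = float(dist)
            node["count"] = int(count)
            node["children"] = [make_node(int(a), node_id),
                                make_node(int(b), node_id)]
        return node

    root_id = n_samples + len(rows) - 1  # the last row creates the root
    return make_node(root_id, None)


# Toy linkage rows (4 samples, 3 merges), not the data from the answer.
toy_rows = [[0, 1, 1.0, 2], [2, 3, 2.0, 2], [4, 5, 5.0, 4]]
toy_tree = build_tree_with_heights(toy_rows)
print(json.dumps(toy_tree, indent=2))
```

The recursion depth equals the depth of the tree, so a very unbalanced hierarchy over thousands of samples might need an iterative version instead.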

    Running code for d3:

    <!DOCTYPE html>
    <html lang="en">
      <head>
        <meta charset="utf-8" />
    
        <title>Tree Example</title>
    
        <style>
          .node {
            cursor: pointer;
          }
    
          .node circle {
            fill: #fff;
            stroke: steelblue;
            stroke-width: 3px;
          }
    
          .node text {
            font: 12px sans-serif;
          }
    
          .link {
            fill: none;
            stroke: #ccc;
            stroke-width: 2px;
          }
        </style>
      </head>
    
      <body>
        <!-- load the d3.js library -->
        <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js"></script>
    
        <script>
          var treeData = {
            name: 18.0,
            parent: null,
            children: [
              {
                name: 16.0,
                parent: 18.0,
                children: [
                  {
                    name: 4.0,
                    parent: 16.0,
                    children: [],
                  },
                  {
                    name: 13.0,
                    parent: 16.0,
                    children: [
                      {
                        name: 0.0,
                        parent: 13.0,
                        children: [],
                      },
                      {
                        name: 11.0,
                        parent: 13.0,
                        children: [
                          {
                            name: 3.0,
                            parent: 11.0,
                            children: [],
                          },
                          {
                            name: 10.0,
                            parent: 11.0,
                            children: [
                              {
                                name: 1.0,
                                parent: 10.0,
                                children: [],
                              },
                              {
                                name: 2.0,
                                parent: 10.0,
                                children: [],
                              },
                            ],
                          },
                        ],
                      },
                    ],
                  },
                ],
              },
              {
                name: 17.0,
                parent: 18.0,
                children: [
                  {
                    name: 8.0,
                    parent: 17.0,
                    children: [],
                  },
                  {
                    name: 15.0,
                    parent: 17.0,
                    children: [
                      {
                        name: 5.0,
                        parent: 15.0,
                        children: [],
                      },
                      {
                        name: 14.0,
                        parent: 15.0,
                        children: [
                          {
                            name: 9.0,
                            parent: 14.0,
                            children: [],
                          },
                          {
                            name: 12.0,
                            parent: 14.0,
                            children: [
                              {
                                name: 6.0,
                                parent: 12.0,
                                children: [],
                              },
                              {
                                name: 7.0,
                                parent: 12.0,
                                children: [],
                              },
                            ],
                          },
                        ],
                      },
                    ],
                  },
                ],
              },
            ],
          };
    
          // ************** Generate the tree diagram    *****************
          var margin = { top: 40, right: 120, bottom: 20, left: 120 },
            width = 960 - margin.right - margin.left,
            height = 600 - margin.top - margin.bottom;
    
          var i = 0;
    
          var tree = d3.layout.tree().size([height, width]);
    
          var diagonal = d3.svg.diagonal().projection(function (d) {
            return [d.x, d.y];
          });
    
          var svg = d3
            .select('body')
            .append('svg')
            .attr('width', width + margin.right + margin.left)
            .attr('height', height + margin.top + margin.bottom)
            .append('g')
            .attr('transform', 'translate(' + margin.left + ',' + margin.top + ')');
    
          var root = treeData;
    
          update(root);
    
          function update(source) {
            // Compute the new tree layout.
            var nodes = tree.nodes(root).reverse(),
              links = tree.links(nodes);
    
            // Normalize for fixed-depth.
            nodes.forEach(function (d) {
              d.y = d.depth * 100;
            });
    
            // Declare the nodes…
            var node = svg.selectAll('g.node').data(nodes, function (d) {
              return d.id || (d.id = ++i);
            });
    
            // Enter the nodes.
            var nodeEnter = node
              .enter()
              .append('g')
              .attr('class', 'node')
              .attr('transform', function (d) {
                return 'translate(' + d.x + ',' + d.y + ')';
              });
    
            nodeEnter.append('circle').attr('r', 10).style('fill', '#fff');
    
            nodeEnter
              .append('text')
              .attr('y', function (d) {
                return d.children || d._children ? -18 : 18;
              })
              .attr('dy', '.35em')
              .attr('text-anchor', 'middle')
              .text(function (d) {
                return d.name;
              })
              .style('fill-opacity', 1);
    
            // Declare the links…
            var link = svg.selectAll('path.link').data(links, function (d) {
              return d.target.id;
            });
    
            // Enter the links.
            link
              .enter()
              .insert('path', 'g')
              .attr('class', 'link')
              .attr('d', diagonal);
          }
        </script>
      </body>
    </html>
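For the slider part of the question, assuming T is a distance threshold at which the dendrogram is cut, the flat clusters for a given T could be computed from the same linkage rows and sent to the page along with the tree. The sketch below walks down from the root and stops at the first node whose merge distance is at or below the threshold. It is meant to be similar in spirit to SciPy's `fcluster` with `criterion='distance'`, but it is a hand-rolled illustration, and the name `cut_tree_at` is hypothetical.

```python
# Rows copied from the linkage matrix in the answer.
linked_rows = [
    [1, 2, 5.83095189, 2],
    [3, 10, 9.21954446, 3],
    [6, 7, 11.18033989, 2],
    [0, 11, 13.0, 4],
    [9, 12, 14.2126704, 3],
    [5, 14, 17.20465053, 4],
    [4, 13, 20.88061302, 5],
    [8, 15, 21.21320344, 5],
    [16, 17, 47.16990566, 10],
]


def cut_tree_at(rows, threshold):
    """Group leaf ids into flat clusters by cutting the hierarchy at `threshold`."""
    n_samples = len(rows) + 1

    def leaves_under(node_id):
        # Collect the original sample ids below a node.
        if node_id < n_samples:
            return [node_id]
        a, b = rows[node_id - n_samples][0], rows[node_id - n_samples][1]
        return leaves_under(int(a)) + leaves_under(int(b))

    clusters = []
    stack = [n_samples + len(rows) - 1]  # start at the root
    while stack:
        node_id = stack.pop()
        is_leaf = node_id < n_samples
        if is_leaf or rows[node_id - n_samples][2] <= threshold:
            # Everything below this node merged at or below the threshold.
            clusters.append(sorted(leaves_under(node_id)))
        else:
            # Merge happened above the threshold: keep descending.
            a, b = rows[node_id - n_samples][0], rows[node_id - n_samples][1]
            stack.extend([int(b), int(a)])
    return sorted(clusters)


for t in (10, 15, 25):
    print(t, cut_tree_at(linked_rows, t))
```

A Django view could call something like this with the slider value and return the result as JSON, which avoids re-rendering a Matplotlib image on every change.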