I feel like very similar versions of this question have been asked, and I have learned a lot while trying to solve this problem, but there is some (probably very basic) concept that I am missing.
Here are three very similar questions with very good answers:
Extract path from root to leaf in sklearn's agglomerative clustering
Yield all root-to-leaf branches of a Binary Tree
How do you visualize a ward tree from sklearn.cluster.ward_tree?
And there is some great stuff in Mike Bostock's D3 Git repo.
Now the specifics of my situation:
I have done some analysis in Python using sklearn's AgglomerativeClustering, and I am generating the dendrograms I would like to see in Matplotlib:
T=7: [image: T=7 dendrogram]
T=2: [image: T=2 dendrogram]
Now I would like to add those dendrograms and some other functionality to a web site. I have built a site using Django.
I have already used some D3 JavaScript functionality to implement a dynamic and interactive tree diagram like this one:
https://bl.ocks.org/d3noob/8375092
And I have made it load the information for each branch from a JSON file, so it is both dynamic and interactive. [image: interactive tree diagram]
Now I want to mimic some of that functionality with the info from the Agglomerative Clustering.
I want to:
1. Make a dendrogram similar to the one from Matplotlib and make it interactive: a slider should let the user change the T value, and the dendrogram should redraw.
   A. I am open to any suggestions. I can brute-force a solution by simply recalculating the dendrogram in Python (as a module inside the Django app), saving an image, and loading that image on the JavaScript side in the template. I think there is probably a more elegant solution with D3, but I am running out of time to do research.
2. Create an interactive tree diagram using the info from the clustering. I would like the dendrogram to be far more interactive, as a tree. It seems like I should be able to use either agglo_model.children_ or the linkage_matrix.
agglo_model.children_:
[[ 35 36][ 13 18][ 19 20]...[ 22 69][ 33 34][ 14 32]]
or the linkage_matrix:
[[ 35. 36. 0. 2. ]
[ 22. 69. 1.73205081 4. ]
...
[ 50. 57. 4.47213595 2. ]
[ 9. 41. 4.69041576 2. ]
...
[116. 126. 12.62713128 36. ]
[128. 129. 17.97652791 66. ]]
The key piece I am missing is how to go from scikit to the following tree format for d3.js
var treeData = [
{
"name": "Top Level",
"parent": "null",
"children": [
{
"name": "Level 2: A",
"parent": "Top Level",
"children": [
{
"name": "Son of A",
"parent": "Level 2: A"
},
{
"name": "Daughter of A",
"parent": "Level 2: A"
}
]
},
{
"name": "Level 2: B",
"parent": "Top Level"
}
]
}
];
Again - any suggestions appreciated.
There's a lot to this question, but I think it boils down to this: how can one get the data from a linkage matrix into a d3 tree diagram?
Given this Python code:
from scipy.cluster.hierarchy import dendrogram, linkage
import numpy as np
from matplotlib import pyplot as plt

X = np.array([[5,3],
              [10,15],
              [15,12],
              [24,10],
              [30,30],
              [85,70],
              [71,80],
              [60,78],
              [70,55],
              [80,91],])

linked = linkage(X, "single")
labelList = range(0, 10)

plt.figure(figsize=(10, 7))
dendrogram(linked,
           orientation='top',
           labels=labelList,
           distance_sort='descending',
           show_leaf_counts=False)
plt.show()
Which generates this linkage matrix:
[[ 1. 2. 5.83095189 2. ]
[ 3. 10. 9.21954446 3. ]
[ 6. 7. 11.18033989 2. ]
[ 0. 11. 13. 4. ]
[ 9. 12. 14.2126704 3. ]
[ 5. 14. 17.20465053 4. ]
[ 4. 13. 20.88061302 5. ]
[ 8. 15. 21.21320344 5. ]
[16. 17. 47.16990566 10. ]]
And this dendrogram:
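As a reading aid for the matrix above, here is a minimal sketch of how I read each row: row i (counting from 0) creates a new cluster with id n + i, where n is the number of original points; the first two columns are the ids of the two clusters being merged, the third is the merge distance, and the fourth is the number of points in the new cluster. The names LINKED and describe_merges are mine, just for illustration, and the matrix is copied in as plain lists so the snippet runs without SciPy.

```python
# Linkage matrix from above, copied as plain Python lists
# (with SciPy available, linked.tolist() gives the same structure).
LINKED = [
    [1, 2, 5.83095189, 2],
    [3, 10, 9.21954446, 3],
    [6, 7, 11.18033989, 2],
    [0, 11, 13.0, 4],
    [9, 12, 14.2126704, 3],
    [5, 14, 17.20465053, 4],
    [4, 13, 20.88061302, 5],
    [8, 15, 21.21320344, 5],
    [16, 17, 47.16990566, 10],
]


def describe_merges(linked):
    """Hypothetical helper: turn each row into a small labelled record."""
    n_leaves = len(linked) + 1  # n points always produce n - 1 merges
    merges = []
    for i, (left, right, dist, count) in enumerate(linked):
        merges.append({
            "id": n_leaves + i,                    # id of the cluster this row creates
            "children": (int(left), int(right)),   # ids of the two merged clusters
            "distance": float(dist),               # distance at which they merge
            "size": int(count),                    # original points in the new cluster
        })
    return merges


merges = describe_merges(LINKED)
for m in merges:
    print(m["id"], "=", m["children"], "at distance", m["distance"])
```

So ids 0 to 9 are the original points, and id 18 (the last row) is the root that contains everything.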
How, then, can we convert the linkage matrix into a data structure compatible with the d3 tree? It took me a bit to grok the structure of the matrix and how it conveys hierarchy. This post and this post really explain it well. So, let's do a little data manipulation in Python and then JSON-ify it out. If you are using Django, you can return this JSON via an API call.
import json

def create_tree(linked):
    # inner function that recursively walks the linkage matrix
    def recurTree(tree):
        k = tree['name']
        # leaf node: it was never created by a merge, so it has no children
        if k not in inter:
            return
        for n in inter[k]:
            # build child node
            node = {
                "name": n,
                "parent": k,
                "children": []
            }
            # add to children
            tree['children'].append(node)
            # recurse into the child
            recurTree(node)

    num_rows, _ = linked.shape
    inter = {}
    i = 0
    # convert the linkage matrix to a dictionary: cluster id -> [left child, right child]
    for row in linked:
        i += 1
        inter[float(i + num_rows)] = [row[0], row[1]]

    # root of the tree data structure: the cluster created by the last merge
    tree = {
        "name": float(i + num_rows),
        "parent": None,
        "children": []
    }
    # start recursion
    recurTree(tree)
    return tree

print(json.dumps(create_tree(linked), indent=2))
This produces JSON like so:
{
"name": 18.0,
"parent": null,
"children": [
{
"name": 16.0,
"parent": 18.0,
"children": [
{
"name": 4.0,
"parent": 16.0,
"children": []
},
{
"name": 13.0,
"parent": 16.0,
"children": [
{
"name": 0.0,
"parent": 13.0,
"children": []
},
{
"name": 11.0,
"parent": 13.0,
"children": [
{
"name": 3.0,
"parent": 11.0,
"children": []
},
{
"name": 10.0,
"parent": 11.0,
"children": [
{
"name": 1.0,
"parent": 10.0,
"children": []
},
{
"name": 2.0,
"parent": 10.0,
"children": []
}
]
}
]
}
]
}
]
},
{
"name": 17.0,
"parent": 18.0,
"children": [
{
"name": 8.0,
"parent": 17.0,
"children": []
},
{
"name": 15.0,
"parent": 17.0,
"children": [
{
"name": 5.0,
"parent": 15.0,
"children": []
},
{
"name": 14.0,
"parent": 15.0,
"children": [
{
"name": 9.0,
"parent": 14.0,
"children": []
},
{
"name": 12.0,
"parent": 14.0,
"children": [
{
"name": 6.0,
"parent": 12.0,
"children": []
},
{
"name": 7.0,
"parent": 12.0,
"children": []
}
]
}
]
}
]
}
]
}
]
}
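The JSON above only carries the hierarchy. For the dendrogram half of the question you would probably also want the merge heights, which live in the third column of the linkage matrix. Here is a rough sketch of a variant that keeps them, built bottom-up instead of recursively. The function name linkage_to_tree and the "distance" and "size" keys are my own choices, not anything d3 requires, and the matrix is hard-coded as plain lists so it runs without SciPy.

```python
import json

# Same linkage matrix as above, as plain lists (linked.tolist() would do the same).
LINKED = [
    [1, 2, 5.83095189, 2],
    [3, 10, 9.21954446, 3],
    [6, 7, 11.18033989, 2],
    [0, 11, 13.0, 4],
    [9, 12, 14.2126704, 3],
    [5, 14, 17.20465053, 4],
    [4, 13, 20.88061302, 5],
    [8, 15, 21.21320344, 5],
    [16, 17, 47.16990566, 10],
]


def linkage_to_tree(linked):
    """Hypothetical helper: nested dict that keeps distance and size per node."""
    n_leaves = len(linked) + 1
    # Start with one node per original point; leaves sit at height 0.
    nodes = {
        i: {"name": i, "distance": 0.0, "size": 1, "children": []}
        for i in range(n_leaves)
    }
    # Rows only refer to ids that already exist, so a single forward pass works.
    for i, (left, right, dist, count) in enumerate(linked):
        nodes[n_leaves + i] = {
            "name": n_leaves + i,
            "distance": float(dist),
            "size": int(count),
            "children": [nodes[int(left)], nodes[int(right)]],
        }
    # The last row creates the root.
    return nodes[n_leaves + len(linked) - 1]


root = linkage_to_tree(LINKED)
tree_json = json.dumps(root, indent=2)
print(tree_json[:200])
```

On the d3 side you could then place each node by its distance value rather than by d.depth * 100, which should give something closer to the Matplotlib dendrogram.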
If we then dump that into the d3 tree example (I made it vertical), you end up with this:
Running code for d3:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Tree Example</title>
<style>
.node {
cursor: pointer;
}
.node circle {
fill: #fff;
stroke: steelblue;
stroke-width: 3px;
}
.node text {
font: 12px sans-serif;
}
.link {
fill: none;
stroke: #ccc;
stroke-width: 2px;
}
</style>
</head>
<body>
<!-- load the d3.js library -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js"></script>
<script>
var treeData = {
name: 18.0,
parent: null,
children: [
{
name: 16.0,
parent: 18.0,
children: [
{
name: 4.0,
parent: 16.0,
children: [],
},
{
name: 13.0,
parent: 16.0,
children: [
{
name: 0.0,
parent: 13.0,
children: [],
},
{
name: 11.0,
parent: 13.0,
children: [
{
name: 3.0,
parent: 11.0,
children: [],
},
{
name: 10.0,
parent: 11.0,
children: [
{
name: 1.0,
parent: 10.0,
children: [],
},
{
name: 2.0,
parent: 10.0,
children: [],
},
],
},
],
},
],
},
],
},
{
name: 17.0,
parent: 18.0,
children: [
{
name: 8.0,
parent: 17.0,
children: [],
},
{
name: 15.0,
parent: 17.0,
children: [
{
name: 5.0,
parent: 15.0,
children: [],
},
{
name: 14.0,
parent: 15.0,
children: [
{
name: 9.0,
parent: 14.0,
children: [],
},
{
name: 12.0,
parent: 14.0,
children: [
{
name: 6.0,
parent: 12.0,
children: [],
},
{
name: 7.0,
parent: 12.0,
children: [],
},
],
},
],
},
],
},
],
},
],
};
// ************** Generate the tree diagram *****************
var margin = { top: 40, right: 120, bottom: 20, left: 120 },
width = 960 - margin.right - margin.left,
height = 600 - margin.top - margin.bottom;
var i = 0;
var tree = d3.layout.tree().size([height, width]);
var diagonal = d3.svg.diagonal().projection(function (d) {
return [d.x, d.y];
});
var svg = d3
.select('body')
.append('svg')
.attr('width', width + margin.right + margin.left)
.attr('height', height + margin.top + margin.bottom)
.append('g')
.attr('transform', 'translate(' + margin.left + ',' + margin.top + ')');
var root = treeData;
update(root);
function update(source) {
// Compute the new tree layout.
var nodes = tree.nodes(root).reverse(),
links = tree.links(nodes);
// Normalize for fixed-depth.
nodes.forEach(function (d) {
d.y = d.depth * 100;
});
// Declare the nodes…
var node = svg.selectAll('g.node').data(nodes, function (d) {
return d.id || (d.id = ++i);
});
// Enter the nodes.
var nodeEnter = node
.enter()
.append('g')
.attr('class', 'node')
.attr('transform', function (d) {
return 'translate(' + d.x + ',' + d.y + ')';
});
nodeEnter.append('circle').attr('r', 10).style('fill', '#fff');
nodeEnter
.append('text')
.attr('y', function (d) {
return d.children || d._children ? -18 : 18;
})
.attr('dy', '.35em')
.attr('text-anchor', 'middle')
.text(function (d) {
return d.name;
})
.style('fill-opacity', 1);
// Declare the links…
var link = svg.selectAll('path.link').data(links, function (d) {
return d.target.id;
});
// Enter the links.
link
.enter()
.insert('path', 'g')
.attr('class', 'link')
.attr('d', diagonal);
}
</script>
</body>
</html>
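For the slider part of the question, one option is to keep the full tree on the server and send a collapsed copy for each T. The sketch below assumes T is a distance threshold (the question does not say which criterion it uses) and that merge distances never decrease on the way up the tree, which holds for the single-linkage example here. Every subtree that merged at a distance of T or less is collapsed into a single leaf. The helper names build_tree, leaves, cut_tree and flat_clusters are made up for this example.

```python
# Same linkage matrix as in the answer, as plain lists (linked.tolist()).
LINKED = [
    [1, 2, 5.83095189, 2],
    [3, 10, 9.21954446, 3],
    [6, 7, 11.18033989, 2],
    [0, 11, 13.0, 4],
    [9, 12, 14.2126704, 3],
    [5, 14, 17.20465053, 4],
    [4, 13, 20.88061302, 5],
    [8, 15, 21.21320344, 5],
    [16, 17, 47.16990566, 10],
]


def build_tree(linked):
    """Nested dict per cluster, keeping the merge distance."""
    n_leaves = len(linked) + 1
    nodes = {i: {"name": i, "distance": 0.0, "children": []} for i in range(n_leaves)}
    for i, (left, right, dist, _count) in enumerate(linked):
        nodes[n_leaves + i] = {
            "name": n_leaves + i,
            "distance": float(dist),
            "children": [nodes[int(left)], nodes[int(right)]],
        }
    return nodes[2 * n_leaves - 2]  # id of the root


def leaves(node):
    """Ids of the original points underneath a node."""
    if not node["children"]:
        return [node["name"]]
    return [leaf for child in node["children"] for leaf in leaves(child)]


def cut_tree(node, t):
    """Copy of the tree with subtrees merged at distance <= t collapsed to one leaf."""
    collapsed = not node["children"] or node["distance"] <= t
    return {
        "name": node["name"],
        "members": sorted(leaves(node)),
        "children": [] if collapsed else [cut_tree(c, t) for c in node["children"]],
    }


def flat_clusters(node):
    """Member lists of the leaves of a cut tree, left to right."""
    if not node["children"]:
        return [node["members"]]
    return [c for child in node["children"] for c in flat_clusters(child)]


root = build_tree(LINKED)
clusters_at_15 = flat_clusters(cut_tree(root, 15))
print(clusters_at_15)  # [[4], [0, 1, 2, 3], [8], [5], [6, 7, 9]]
```

The slider handler would request the cut tree for the new T and call update() again with it, so nothing has to be rendered as an image on the Python side.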