If I plot a (correlation) dataframe with sns.clustermap, it automatically takes the dataframe's MultiIndex as labels and plots them to the right of and below the clustermap.
How do I access these labels? I'm using clustermaps as an exploratory tool for large-ish datasets (100-200 entries), and I need the names of the entries in the various clusters.
EXAMPLE:
import numpy as np
import pandas as pd
import seaborn as sns

elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number],
                                names=('name', 'elev', 'number'))
data = np.random.rand(20, 6)
df = pd.DataFrame(data=data, columns=idx)
clustermap = sns.clustermap(df.corr())
gives a clustermap of the 6 × 6 correlation matrix.
Now I'd say that there are two distinct clusters: the first two rows and the last four rows, so [foo-1-1, bar-100-2] and [baz-10-3, qux-1000-4, quux-100-5, quuux-10-6].
How can I extract these (or the whole [foo-1-1, bar-100-2, baz-10-3, qux-1000-4, quux-100-5, quuux-10-6] list)? With 100+ entries, writing them down by hand isn't really an option.
The documentation offers clustergrid.dendrogram_row.reordered_ind, but that only gives me the positional indices into the original dataframe. I'm looking for something more like the output of df.columns.
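For reference, the positions in reordered_ind can be mapped back to labels: they index the rows of the matrix that was passed to clustermap, so indexing that matrix's index with them should give the MultiIndex entries in plotted order (top to bottom). A minimal sketch, assuming the default row clustering is enabled; the Agg backend line is only there so it runs without a display, and the "-".join mimics how seaborn renders MultiIndex tick labels.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

# Same toy data as in the question
elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number],
                                names=('name', 'elev', 'number'))
df = pd.DataFrame(data=np.random.rand(20, 6), columns=idx)
corr = df.corr()

cg = sns.clustermap(corr)

# Positions into corr's rows, listed in the order the rows are drawn
row_order = cg.dendrogram_row.reordered_ind

# Fancy-indexing the original MultiIndex gives the entries in plotted order
ordered_index = corr.index[row_order]

# Tuples like ('foo', 1, 1) joined into strings like 'foo-1-1'
ordered_labels = ["-".join(map(str, t)) for t in ordered_index]
```

Because ordered_index is still a MultiIndex, single levels are available without string parsing, e.g. ordered_index.get_level_values('number').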
That seems to point in the right direction, but with it I can only extract which cluster a given row belongs to when I let it form the clusters automatically. I'd like to define the clusters myself, visually.
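Once the rows are available in plotted order, clusters picked by eye are just consecutive runs of that list. A minimal sketch with a hypothetical helper split_at; the cut positions are row numbers read off the plot by hand, and the ordered list is copied from the example above (with other, random data the order will differ).

```python
def split_at(ordered_labels, cut_positions):
    """Split labels (in dendrogram order) into consecutive groups.

    cut_positions: row positions where a new cluster starts, chosen by
    looking at the clustermap (hypothetical values in the usage below).
    """
    bounds = [0, *sorted(cut_positions), len(ordered_labels)]
    return [ordered_labels[a:b] for a, b in zip(bounds, bounds[1:])]


# Order as drawn in the question's plot (assumed for illustration)
ordered = ['foo-1-1', 'bar-100-2', 'baz-10-3', 'qux-1000-4', 'quux-100-5', 'quuux-10-6']

# A new cluster starts at row 2, i.e. after the first two rows
clusters = split_at(ordered, [2])
# [['foo-1-1', 'bar-100-2'], ['baz-10-3', 'qux-1000-4', 'quux-100-5', 'quuux-10-6']]
```

With 100+ rows, only the handful of cut positions has to be noted instead of every name.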
As always with such things, the answer was out there; I just overlooked it.
This answer (pointed out by Trenton McKinney in the comments) contains the needed snippet:
ax_heatmap.yaxis.get_majorticklabels()
(I wouldn't have looked into ax_heatmap to get to that...) So, continuing the MWE from the question:
labels = clustermap.ax_heatmap.yaxis.get_majorticklabels()
However, that's a list of Text objects:
type(labels[0])
matplotlib.text.Text
So unless I'm missing something (again), it's not exactly straightforward to use. Still, it can simply be looped into something more useful. Let's say I'm interested in the whole name (i.e. the complete former df MultiIndex) and the number:
labels_list = []
number_list = []
for i in labels:
    # str(i) is something like "Text(0, 0.5, 'foo-1-1')"; keep the part between the quotes
    i = str(i)
    name_start = i.find('\'') + 1
    name_end = i.rfind('\'')
    name = i[name_start:name_end]
    # the number is the last '-'-separated field of the label
    number_start = name.rfind('-') + 1
    number = int(name[number_start:])
    labels_list.append(name)
    number_list.append(number)
Now I've got two easily workable lists, one with full strings and one with ints.
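A possibly simpler variant of the loop, as a sketch: matplotlib's Text.get_text() returns the label string directly, which avoids parsing str(i). Passing yticklabels=True to clustermap (it is forwarded to sns.heatmap) asks seaborn to draw every row label; with the default "auto", seaborn may skip labels on large matrices, in which case the tick-label list would not cover all 100+ entries.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

# Same toy data as in the question
elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number],
                                names=('name', 'elev', 'number'))
df = pd.DataFrame(data=np.random.rand(20, 6), columns=idx)

# yticklabels=True: draw a label for every row instead of a thinned-out subset
clustermap = sns.clustermap(df.corr(), yticklabels=True)
labels = clustermap.ax_heatmap.yaxis.get_majorticklabels()

# get_text() returns the label string, e.g. 'foo-1-1'
labels_list = [t.get_text() for t in labels]
# The last '-'-separated field is the 'number' level of the former MultiIndex
number_list = [int(s.rsplit('-', 1)[1]) for s in labels_list]
```

Splitting from the right with rsplit keeps the number extraction correct even if a name itself contains a '-'.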