
How to extract the labels from sns.clustermap


If I'm plotting a (correlation) dataframe with sns.clustermap, it automatically takes the dataframe's MultiIndex as labels and plots them to the right of and below the clustermap.

How do I access these labels? I'm using clustermaps as an exploratory tool for large-ish datasets (100-200 entries), and I need the names of the entries in the various clusters.

EXAMPLE:

import numpy as np
import pandas as pd
import seaborn as sns

elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number], 
                                 names=('name','elev', 'number'))
data = np.random.rand(20,6)
df = pd.DataFrame(data=data, columns=idx)

clustermap = sns.clustermap(df.corr())

gives

[Figure: sample clustermap plot]

Now I'd say that there are two distinct clusters: the first two rows and the last four rows, so [foo-1-1, bar-100-2] and [baz-10-3, qux-1000-4, quux-100-5, quuux-10-6].

How can I extract these (or the whole [foo-1-1, bar-100-2, baz-10-3, qux-1000-4, quux-100-5, quuux-10-6] list)? With 100+ entries, just writing them down by hand isn't really an option.

The documentation offers clustergrid.dendrogram_row.reordered_ind, but that just gives me the index numbers in the original dataframe. I'm looking for something more like the output of df.columns.
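Since reordered_ind is a list of integer positions into the data the clustermap was drawn from (here df.corr()), indexing that frame's index with it should give the row names in plotted order. A minimal sketch, assuming a recent seaborn that labels MultiIndex rows by joining the levels with "-"; the "-".join step below only mimics that label format and is not a seaborn function.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; drop this line in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

# Rebuild the example from the question
elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number],
                                names=('name', 'elev', 'number'))
df = pd.DataFrame(data=np.random.rand(20, 6), columns=idx)

corr = df.corr()
clustermap = sns.clustermap(corr)

# Integer positions into corr's rows, in dendrogram (top-to-bottom) order
row_order = clustermap.dendrogram_row.reordered_ind

# Index the MultiIndex with those positions: tuples like ('foo', 1, 1)
ordered_index = corr.index[row_order]

# Join each tuple into the 'foo-1-1' style string shown on the plot
ordered_labels = ['-'.join(map(str, t)) for t in ordered_index]
print(ordered_labels)
```

The same works for the columns with clustermap.dendrogram_col.reordered_ind and corr.columns. Because the data is random, the order will differ from run to run.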

With this, it seems to me I'm heading in the right direction, but I can only extract which cluster a given row belongs to when I let it form the clusters automatically. I'd like to define the clusters myself, visually.
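Once the row labels are available in dendrogram order, a cluster picked out by eye is just a contiguous run of that list, so it can be cut at row positions read off the plot. A small sketch under that assumption: split_at is a hypothetical helper, the boundary [2] is a hand-picked value matching the "first two rows vs. last four" split described above, and ordered_labels is hardcoded from the question so the block runs on its own.

```python
def split_at(labels, boundaries):
    """Split an ordered label list into consecutive clusters.

    boundaries: row positions (counted from the top of the heatmap) where a
    new cluster starts; e.g. [2] puts the first two rows in their own cluster.
    """
    edges = [0, *boundaries, len(labels)]
    return [labels[start:stop] for start, stop in zip(edges, edges[1:])]


# Row labels in top-to-bottom plot order (hardcoded from the question)
ordered_labels = ['foo-1-1', 'bar-100-2', 'baz-10-3',
                  'qux-1000-4', 'quux-100-5', 'quuux-10-6']

# Hypothetical cut chosen visually: first two rows vs. the remaining four
clusters = split_at(ordered_labels, [2])
for i, members in enumerate(clusters, start=1):
    print(f"cluster {i}: {members}")
```

This prints the two groups [foo-1-1, bar-100-2] and [baz-10-3, qux-1000-4, quux-100-5, quuux-10-6]; more clusters just need more boundaries, e.g. [2, 4].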


Solution

  • As always with such things, the answer was out there; I just overlooked it.

    This answer (pointed out by Trenton McKinney in the comments) has the needed snippet in it:

    ax_heatmap.yaxis.get_majorticklabels()
    

    (I wouldn't have thought to look in ax_heatmap for that.) So, continuing the MWE from the question:

    labels = clustermap.ax_heatmap.yaxis.get_majorticklabels()
    

    However, the elements of that list are Text objects, not strings:

    >>> type(labels[0])
    <class 'matplotlib.text.Text'>
    

    So unless I'm missing something (again), they're not exactly straightforward to use. They can, however, simply be looped into something more useful. Let's say I'm interested in the whole name (i.e. the complete former df MultiIndex) and the number:

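    labels_list = []
    number_list = []
    for label in labels:
        # str(label) looks something like "Text(0, 0.5, 'foo-1-1')";
        # keep only the part between the outer quotes
        text = str(label)
        name_start = text.find('\'') + 1
        name_end = text.rfind('\'')
        name = text[name_start:name_end]
        # the number is the last '-'-separated part of the joined MultiIndex label
        number_start = name.rfind('-') + 1
        number = int(name[number_start:])
        labels_list.append(name)
        number_list.append(number)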
    

    Now I've got two easily workable lists, one with full strings and one with ints.
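Two details may make this simpler and more reliable, sketched below assuming a reasonably recent matplotlib and seaborn. First, each matplotlib.text.Text has a get_text() method that returns the label string directly, so slicing str(label) shouldn't be necessary. Second, seaborn's default yticklabels="auto" may skip labels when there are more rows than fit on the axis, so with 100+ entries the tick labels can be only a subset; passing yticklabels=True asks seaborn to draw every row label.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; drop this line in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

# Example data from the question
elev = [1, 100, 10, 1000, 100, 10]
number = [1, 2, 3, 4, 5, 6]
name = ['foo', 'bar', 'baz', 'qux', 'quux', 'quuux']
idx = pd.MultiIndex.from_arrays([name, elev, number],
                                names=('name', 'elev', 'number'))
df = pd.DataFrame(data=np.random.rand(20, 6), columns=idx)

# yticklabels=True: label every row instead of letting seaborn thin them out
clustermap = sns.clustermap(df.corr(), yticklabels=True)

# get_text() returns the plain label string, e.g. 'foo-1-1'
ticks = clustermap.ax_heatmap.yaxis.get_majorticklabels()
labels_list = [t.get_text() for t in ticks]

# The last '-'-separated part of each label is the 'number' level
number_list = [int(label.rsplit('-', 1)[1]) for label in labels_list]
print(labels_list)
print(number_list)
```

Both lists come out in the top-to-bottom order of the heatmap rows, like the ones built by the loop above.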