Search code examples
Tags: python, dataframe, seaborn, heatmap, multi-index

How to arrange y-labels in seaborn clustermap when using a multiindex dataframe?


I am trying to customize y-labels of a clustermap from seaborn with a multiindex dataframe. So I have a dataframe that looks like this :

                    Col1    Col2    ...
Idx1.A    Idx2.a    1.05    1.51    ...
          Idx2.b    0.94    0.88    ...
Idx1.B    Idx2.c    1.09    1.20    ...
          Idx2.d    0.90    0.79    ...
   ...       ...     ...     ...    ...

The goal is to get y-labels laid out like in this example, where, in my case, Idx1 would be the seasons, Idx2 would be the months and the columns would be the years. (That example is a heatmap, not a clustermap, so I think the seaborn functions for customizing the ticks are different, even though a clustermap just adds a hierarchical clustering over the rows or columns of a heatmap.) [enter image description here] My code:

def do_clustermap(path='/home/Documents/myfile.csv'):
    """Draw a clustermap from a tab-separated file with a two-column row index.

    Parameters
    ----------
    path : str, optional
        Location of the tab-separated file to plot. Its first two columns
        are used as the row index (``index_col=[0, 1]``). Defaults to the
        location that was previously hard-coded.

    The function returns nothing; it displays the figure with ``plt.show()``.
    """
    # Only the parsing needs the file handle, so the ``with`` block ends
    # before any plotting starts and the file is not held open while the
    # figure is being displayed.
    with open(path, 'r') as f:
        df = pd.read_csv(f, index_col=[0, 1], sep='\t')

    # row_cluster=False is passed so that rows are not clustered.
    g = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu",
                       yticklabels=True, xticklabels=True, linewidths=0.004)
    g.ax_heatmap.yaxis.set_ticks_position("left")

    # Shrink the tick labels on both heatmap axes.
    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=4)
    plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=4)
    plt.show()

I tried to follow the answers from this thread but it gives this message :

UserWarning: Clustering large matrix with scipy. Installing `fastcluster` may give better performance.
  warnings.warn(msg)
Traceback (most recent call last):
  File "/home/ju/PycharmProjects/stage/figures.py", line 24, in <module>
    do_heatmap()
  File "/home/ju/PycharmProjects/stage/figures.py", line 13, in do_heatmap
    ax = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1412, in clustermap
    tree_kws=tree_kws, **kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1223, in plot
    tree_kws=tree_kws)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1079, in plot_dendrograms
    tree_kws=tree_kws
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 776, in dendrogram
    label=label, rotate=rotate)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 584, in __init__
    self.linkage = self.calculated_linkage
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 651, in calculated_linkage
    return self._calculate_linkage_scipy()
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 620, in _calculate_linkage_scipy
    metric=self.metric)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1038, in linkage
    y = _convert_to_double(np.asarray(y, order='c'))
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1560, in _convert_to_double
    X = X.astype(np.double)
ValueError: could not convert string to float: 'Col1'

Does anyone have an idea? Here is a small example of the file I'm working with:

        Robert  Jean    Lulu
Bar a   1.05    1.52    1.16
Bar b   0.94    0.49    0.83
Foo c   1.09    1.22    1.44
Foo d   0.92    0.79    0.55
Hop e   0.62    0.82    0.68
Hop f   0.52    0.18    0.31
Hop g   0.93    1.15    1.11

Solution

  • Here is some code creating a minimal example similar to the given data.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    # Toy data: two label columns that become the row MultiIndex, plus two
    # columns of random values (Col1 is drawn before Col2).
    records = {
        'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
        'Col1': np.random.rand(7),
        'Col2': np.random.rand(7),
    }
    df = pd.DataFrame(records).set_index(['Idx1', 'Idx2'])

    # row_cluster=False is passed so that rows are not clustered.
    g = sns.clustermap(
        df,
        center=1,
        row_cluster=False,
        cmap="YlGnBu",
        yticklabels=True,
        xticklabels=True,
        linewidths=0.004,
    )
    heatmap_ax = g.ax_heatmap
    heatmap_ax.yaxis.set_ticks_position("left")

    # Same font size for the tick labels of both heatmap axes.
    for axis in (heatmap_ax.xaxis, heatmap_ax.yaxis):
        plt.setp(axis.get_majorticklabels(), fontsize=10)
    plt.show()
    

    The dataframe looks like:

                   Col1      Col2
    Idx1 Idx2                    
    Bar  a     0.366961  0.253956
         b     0.320457  0.807694
    Foo  c     0.293184  0.337154
         d     0.868155  0.661968
    Hop  e     0.908930  0.406291
         f     0.670220  0.668903
         g     0.683821  0.476246
    

    With seaborn 0.11.1, matplotlib 3.4.2, pandas 1.2.4 and scipy 1.6.3, the following plot is generated:

    clustermap example with two indices

    An integration with the linked code could look like the following. Some distances will need to be adjusted depending on the figure size and the length of the labels.

    import matplotlib.pyplot as plt
    from itertools import groupby
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    def add_line(ax, xpos, ypos):
        """Add a short black horizontal segment to ``ax`` in axes coordinates.

        Mind the argument naming: ``xpos`` is used as the vertical (y)
        coordinate of the segment, while ``ypos`` is the horizontal (x)
        coordinate where it starts. The segment extends 0.2 axes units to
        the right of ``ypos``.
        """
        x_coords = [ypos, ypos + .2]
        y_coords = [xpos, xpos]
        segment = plt.Line2D(x_coords, y_coords, color='black', transform=ax.transAxes)
        # Turn clipping off for this segment before attaching it to the axes.
        segment.set_clip_on(False)
        ax.add_line(segment)
    
    def label_len(my_index, level):
        """Return ``(label, run_length)`` pairs for one level of an index.

        Consecutive equal labels of ``my_index.get_level_values(level)`` are
        collapsed into a single pair; a label that reappears after a
        different one starts a new pair.
        """
        runs = []
        for key, members in groupby(my_index.get_level_values(level)):
            runs.append((key, len(list(members))))
        return runs
    
    def label_group_bar_table(ax, df):
        """Draw grouped index labels and separator ticks to the left of ``ax``.

        Each level of ``df.index`` gets its own column of labels, positioned
        in axes coordinates. Level 0 starts at x = -0.2 and every further
        level is placed another 0.2 to the left. Within a level, runs of
        identical consecutive labels are stacked from the top of the axes
        downwards, each run getting one label and a separator tick at its
        upper edge, plus a final tick below the last run.
        """
        row_count = df.index.size
        column_x = -.2
        scale = 1. / row_count  # height of one row in axes coordinates
        for level in range(df.index.nlevels):
            remaining = row_count
            for text, run_length in label_len(df.index, level):
                # Separator at the upper edge of this run.
                add_line(ax, remaining * scale, column_x)
                remaining -= run_length
                # Place the label at the middle of the run's vertical span.
                label_y = (remaining + .5 * run_length) * scale
                ax.text(column_x + .1, label_y, text, ha='center', transform=ax.transAxes)
            # Closing separator below the last run of this level.
            add_line(ax, remaining * scale, column_x)
            column_x -= .2
    
    # Toy data: two label columns plus two columns of random values
    # (Col1 is drawn before Col2).
    records = {
        'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
        'Col1': np.random.rand(7),
        'Col2': np.random.rand(7),
    }
    # Idx2 is made level 0 of the index: label_group_bar_table places level 0
    # closest to the heatmap and each further level more to the left.
    df = pd.DataFrame(records).set_index(['Idx2', 'Idx1'])

    g = sns.clustermap(
        df,
        center=1,
        row_cluster=False,
        cmap="YlGnBu",
        yticklabels=True,
        xticklabels=True,
        linewidths=0.004,
        figsize=(10, 5),
    )
    heatmap_ax = g.ax_heatmap
    heatmap_ax.yaxis.set_ticks_position("left")

    plt.setp(heatmap_ax.xaxis.get_majorticklabels(), fontsize=10)
    # Clear the y ticks; the grouped labels are drawn by hand instead.
    heatmap_ax.set_yticks([])
    label_group_bar_table(heatmap_ax, df)
    # Adjust the left subplot margin of the figure.
    g.fig.subplots_adjust(left=0.15)
    plt.show()
    

    sns.clustermap with double indices, custom labeling