I am trying to customize the y-labels of a seaborn clustermap drawn from a MultiIndex DataFrame. I have a DataFrame that looks like this:
Col1 Col2 ...
Idx1.A Idx2.a 1.05 1.51 ...
Idx2.b 0.94 0.88 ...
Idx1.B Idx2.c 1.09 1.20 ...
Idx2.d 0.90 0.79 ...
... ... ... ... ...
The goal is to get y-labels like those in the linked example, where in my case Idx1 would be the seasons, Idx2 the months and the columns the years. The difference is that I am using a clustermap rather than a heatmap, so I suspect the seaborn functions for customizing the ticks are different, even though a clustermap just adds hierarchical clustering over the rows or columns of a heatmap. My code:
def do_clustermap():
with open('/home/Documents/myfile.csv', 'r') as f:
df = pd.read_csv(f, index_col=[0, 1], sep='\t')
g = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
g.ax_heatmap.yaxis.set_ticks_position("left")
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=4)
plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=4)
plt.show()
I tried to follow the answers from this thread, but I get the following message:
UserWarning: Clustering large matrix with scipy. Installing `fastcluster` may give better performance.
warnings.warn(msg)
Traceback (most recent call last):
File "/home/ju/PycharmProjects/stage/figures.py", line 24, in <module>
do_heatmap()
File "/home/ju/PycharmProjects/stage/figures.py", line 13, in do_heatmap
ax = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
return f(**kwargs)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1412, in clustermap
tree_kws=tree_kws, **kwargs)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1223, in plot
tree_kws=tree_kws)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1079, in plot_dendrograms
tree_kws=tree_kws
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
return f(**kwargs)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 776, in dendrogram
label=label, rotate=rotate)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 584, in __init__
self.linkage = self.calculated_linkage
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 651, in calculated_linkage
return self._calculate_linkage_scipy()
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 620, in _calculate_linkage_scipy
metric=self.metric)
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1038, in linkage
y = _convert_to_double(np.asarray(y, order='c'))
File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1560, in _convert_to_double
X = X.astype(np.double)
ValueError: could not convert string to float: 'Col1'
Does anyone have an idea? Here is a small example of the file I'm working with:
Robert Jean Lulu
Bar a 1.05 1.52 1.16
Bar b 0.94 0.49 0.83
Foo c 1.09 1.22 1.44
Foo d 0.92 0.79 0.55
Hop e 0.62 0.82 0.68
Hop f 0.52 0.18 0.31
Hop g 0.93 1.15 1.11
Here is some code creating a minimal example similar to the given data.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Minimal reproducible example: two random value columns indexed by a
# (season-like, month-like) MultiIndex. The random columns are drawn Col1
# first, then Col2, exactly as before.
df = pd.DataFrame(
    dict(
        Idx1=['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        Idx2=list('abcdefg'),
        Col1=np.random.rand(7),
        Col2=np.random.rand(7),
    )
).set_index(['Idx1', 'Idx2'])

# Rows keep their order (no row dendrogram); columns are clustered.
g = sns.clustermap(
    df,
    center=1,
    row_cluster=False,
    cmap="YlGnBu",
    yticklabels=True,
    xticklabels=True,
    linewidths=0.004,
)
g.ax_heatmap.yaxis.set_ticks_position("left")
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=10)
plt.setp(g.ax_heatmap.yaxis.get_majorticklabels(), fontsize=10)
plt.show()
The dataframe looks like:
Col1 Col2
Idx1 Idx2
Bar a 0.366961 0.253956
b 0.320457 0.807694
Foo c 0.293184 0.337154
d 0.868155 0.661968
Hop e 0.908930 0.406291
f 0.670220 0.668903
g 0.683821 0.476246
With seaborn 0.11.1, matplotlib 3.4.2, pandas 1.2.4 and scipy 1.6.3, the following plot is generated:
An integration with the linked code could look like the following. Some distances will need to be adjusted depending on the data and the figure size.
import matplotlib.pyplot as plt
from itertools import groupby
import seaborn as sns
import pandas as pd
import numpy as np
def add_line(ax, xpos, ypos, length=.2):
    """Draw a black horizontal line segment on *ax* in axes coordinates.

    The parameter names are swapped relative to their role. *xpos* is the
    vertical (y) position of the line, and *ypos* is the horizontal (x)
    coordinate where it starts. The names are kept so existing callers keep
    working.

    Args:
        ax: Matplotlib axes to draw on.
        xpos: Vertical position of the line, in axes coordinates.
        ypos: Horizontal start of the line, in axes coordinates.
        length: Horizontal extent of the line in axes coordinates. The
            default 0.2 reproduces the previously hard-coded length.
    """
    line = plt.Line2D([ypos, ypos + length], [xpos, xpos], color='black',
                      transform=ax.transAxes)
    # Callers in this file draw these lines left of the axes (negative x),
    # so clipping to the axes area must be disabled for them to show.
    line.set_clip_on(False)
    ax.add_line(line)
def label_len(my_index, level):
    """Run-length encode the labels of one level of an index.

    Returns a list of ``(label, count)`` pairs, one per run of consecutive
    equal labels in ``my_index.get_level_values(level)``, in order. A label
    that reappears after a different one starts a new run.
    """
    runs = groupby(my_index.get_level_values(level))
    return [(key, len(list(members))) for key, members in runs]
def label_group_bar_table(ax, df):
    """Draw grouped MultiIndex row labels to the left of *ax*.

    Each index level gets its own column of labels. Level 0 is placed closest
    to the axes, and each further level sits 0.2 axes-widths further left.
    Within a level, consecutive equal labels (see ``label_len``) are merged
    into one group. Each group is bounded by horizontal lines, and its label
    is centred on it.

    Rows are laid out from the top (y=1) to the bottom (y=0) in axes
    coordinates. This assumes the first row of *df* is drawn at the top of
    *ax*, as in the heatmap this is used with; confirm before using it with
    other plots.

    Args:
        ax: Matplotlib axes whose rows correspond to the rows of *df*.
        df: DataFrame whose (Multi)Index provides the labels.
    """
    nrows = df.index.size
    if nrows == 0:
        # Nothing to label; also avoids a ZeroDivisionError in the scale below.
        return
    xpos = -.2
    scale = 1. / nrows
    for level in range(df.index.nlevels):
        pos = nrows
        for label, rpos in label_len(df.index, level):
            add_line(ax, pos * scale, xpos)  # separator above this group
            pos -= rpos
            lypos = (pos + .5 * rpos) * scale  # vertical centre of the group
            # va='center' centres the text on lypos. The default baseline
            # alignment would put the text's baseline there, so the label
            # would sit above the middle of its group.
            ax.text(xpos + .1, lypos, label, ha='center', va='center',
                    transform=ax.transAxes)
        add_line(ax, pos * scale, xpos)  # closing line below the last group
        xpos -= .2
# Same example data, but indexed (Idx2, Idx1). Level 0 (the fine-grained
# labels) is drawn next to the heatmap and level 1 (the groups) further out.
# Random columns are still drawn Col1 first, then Col2.
df = pd.DataFrame(
    dict(
        Idx1=['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        Idx2=list('abcdefg'),
        Col1=np.random.rand(7),
        Col2=np.random.rand(7),
    )
).set_index(['Idx2', 'Idx1'])

g = sns.clustermap(
    df,
    center=1,
    row_cluster=False,
    cmap="YlGnBu",
    yticklabels=True,
    xticklabels=True,
    linewidths=0.004,
    figsize=(10, 5),
)
g.ax_heatmap.yaxis.set_ticks_position("left")
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=10)
# Replace the default y tick labels with the grouped label table.
g.ax_heatmap.set_yticks([])
label_group_bar_table(g.ax_heatmap, df)
# Leave room on the left for the labels drawn outside the axes.
g.fig.subplots_adjust(left=0.15)
plt.show()