Search code examples
Tags: python, matplotlib, scipy, hierarchical-clustering, dendrogram

Adding colorbars to clustered heatmaps


I am trying to replicate this type of plot: a clustered heatmap with colored strips next to the dendrogram leaves. [Image: heatmap with colorbars as leaves]

This is what I've done so far:

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
from scipy.cluster.hierarchy import dendrogram, linkage

# Load the input CSV, using its first column as the row index.
fid_df = pd.read_csv(filepath_or_buffer=fid_file, index_col=[0])

# scale data
def scale(x):
    """Return ``log2(x + 1)`` for a single numeric value.

    The ``+ 1`` offset maps 0 to 0.0 instead of -inf.

    Raises:
        ValueError: if ``x <= -1`` (logarithm of a non-positive number).
    """
    # ``np.math`` was merely an alias of the stdlib ``math`` module; NumPy
    # deprecated it in 1.25 and removed it in 2.0, so use ``math`` directly.
    return math.log2(x + 1)
# Apply the log2(x + 1) scaling element-wise. DataFrame.applymap is
# deprecated since pandas 2.1 in favour of DataFrame.map; keep working on
# older pandas. (Check the class, not the instance, so a column named "map"
# cannot be mistaken for the method.)
if hasattr(pd.DataFrame, "map"):
    fid_df = fid_df.map(scale)
else:
    fid_df = fid_df.applymap(scale)

# Cluster the columns: pairwise Euclidean distances, Ward linkage.
data_1D_X = ssd.pdist(fid_df.T, 'euclidean')
X = sch.linkage(data_1D_X, method='ward')
# Cluster the rows. Ward linkage is only correctly defined for Euclidean
# distances (see scipy.cluster.hierarchy.linkage), so the rows use
# 'euclidean' too. To keep cityblock distances, pair them with
# method='average' instead of 'ward'.
data_1D_Y = ssd.pdist(fid_df, 'euclidean')
Y = sch.linkage(data_1D_Y, method='ward')

fig = plt.figure(figsize=(8, 8))

# Row dendrogram, to the left of the heatmap. Pass ax explicitly rather
# than relying on whichever axes happens to be current.
ax1 = fig.add_axes([0.09, 0.1, 0.2, 0.6])
Z1 = sch.dendrogram(Y, orientation='left', ax=ax1)
ax1.set_xticks([])
ax1.set_yticks([])

# Column dendrogram, above the heatmap.
ax2 = fig.add_axes([0.3, 0.71, 0.6, 0.2])
Z2 = sch.dendrogram(X, ax=ax2)
ax2.set_xticks([])
ax2.set_yticks([])

# Heatmap, with rows and columns reordered to follow the dendrogram leaves.
axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])
idx1 = Z1['leaves']
idx2 = Z2['leaves']
D = fid_df.values[np.ix_(idx1, idx2)]
# origin='lower' draws the first leaf-ordered row at the bottom, so the row
# order runs bottom-to-top alongside the left-hand dendrogram.
im = axmatrix.matshow(D, aspect='auto', origin='lower', cmap=plt.cm.YlGnBu)
axmatrix.set_xticks([])
axmatrix.set_yticks([])

Example output: [Image: example]

However, I also need to add color strips along the rows and columns that show which original group each row and column belongs to. Any idea how to do this?


Solution

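  • Something like this? Add a thin extra axes directly above (or beside) the heatmap, spanning the same data range, and paint each position in its group's color: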

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import ListedColormap

    # Demo: a heatmap (ax1) with a thin strip of categorical colors (ax2)
    # stacked directly on top of it, covering the same x-range.
    fig = plt.figure()
    ax1 = fig.add_axes((0, 0, 1, 0.9))
    ax2 = fig.add_axes((0, 0.9, 1, 0.1))

    gridY, gridX = np.mgrid[0:10:11 * 1j, 0:10:11 * 1j]
    ax1.pcolormesh(gridX, gridY, np.sqrt(gridX ** 2 + gridY ** 2))

    randCol = ['red', 'blue']
    # Create the random generator once instead of once per strip position.
    rng = np.random.default_rng()
    # Cell boundaries along the strip. In a real clustered heatmap there would
    # be one cell per (leaf-ordered) column, and the codes would be the group
    # label of each column instead of random picks.
    edges = np.linspace(0, 10, 1001)
    codes = rng.integers(len(randCol), size=(1, edges.size - 1))
    # Draw the whole strip as one mesh instead of one axvline artist per
    # position. vmin/vmax centre each integer code on its own colormap entry.
    ax2.pcolormesh(
        edges,
        [0, 1],
        codes,
        cmap=ListedColormap(randCol),
        vmin=-0.5,
        vmax=len(randCol) - 0.5,
    )
    ax2.set_xlim((0, 10))
    ax2.tick_params(labelbottom=False, bottom=False, labelleft=False, left=False)
    fig.savefig('so.png', bbox_inches='tight')
    

    [Image: resulting plot]