Search code examples
Tags: python, r, matplotlib, seaborn, hierarchical-clustering

How can I create heatmaps/clustermaps with complex annotations in Seaborn/Matplotlib?


I am working with tumour image expression data for a number of patients; for each patient I have a list of image features extracted from their tumour. I have clustered the patients and the features using hierarchical agglomerative clustering (HAC) and plotted the result with Seaborn's clustermap. This is what I have so far:

Now, each patient has several pieces of categorical information associated with it: cancer subtype (A, B, C, D), T stage (1, 2, 3, 4), N stage (0, 1, 2, 3), M stage (0, 1), and the HAC cluster it belongs to (1, 2, 3, ...). In addition, each image feature belongs to a feature class. I would like to display this categorical information along each axis (I am aware of {row, col}_colors). Essentially, I am trying to recreate the plot below, and I am wondering whether that is possible at all with matplotlib/seaborn in Python.

Also, what do you think the authors used to generate this figure (it was made back in 2014)? R?

My code with some random data:

# Minimal reproduction: cluster patients and radiomics features with Ward HAC
# and draw the re-ordered matrix with seaborn's clustermap.
import numpy as np
import scipy.cluster.hierarchy
import seaborn as sns

# Random dummy data.
# Uniform noise z-scored per feature (column) so that it really is a matrix of
# z-scores and the vmin=-2 / vmax=2 colour limits used below are meaningful.
raw_features = np.random.random((420, 1218))  # [patients, features]
np_zfeatures = (raw_features - raw_features.mean(axis=0)) / raw_features.std(axis=0)
patient_T_stage = np.random.randint(low=1, high=5, size=(420,))
patient_N_stage = np.random.randint(low=0, high=4, size=(420,))
patient_M_stage = np.random.randint(low=0, high=2, size=(420,))
patient_O_stage = np.random.randint(low=0, high=5, size=(420,))
patient_subtype = np.random.randint(low=0, high=5, size=(420,))
feature_class = np.random.randint(low=0, high=5, size=(1218,))  # 5 feature categories (first order, shape, textural, wavelet, LoG)

# HAC clustering (compute linkage matrices).
# linkage() clusters the ROWS of its input: the rows of np_zfeatures are
# patients and the rows of its transpose are features.
method = 'ward'
patient_links = scipy.cluster.hierarchy.linkage(np_zfeatures, method=method, metric='euclidean')
feature_links = scipy.cluster.hierarchy.linkage(np_zfeatures.T, method=method, metric='euclidean')

# Plot the re-ordered cluster map of the TRANSPOSED matrix:
# rows = features, columns = patients.
cbar_kws = {
    'orientation': 'vertical',
    'label': 'feature Z-score',
    'extend': 'both',
    'extendrect': True,
}
arguments = {
    'row_cluster': True,
    'col_cluster': True,
    'row_linkage': feature_links,  # rows of the plotted matrix are features
    'col_linkage': patient_links,  # columns of the plotted matrix are patients
}
cmap = 'Spectral_r'
cg = sns.clustermap(
    np_zfeatures.T,
    **arguments,
    cmap=cmap,
    vmin=-2,
    vmax=2,
    cbar_pos=(0.155, 0.644, 0.04, 0.15),
    cbar_kws=cbar_kws,
)
cg.ax_row_dendrogram.set_visible(False)
cg.ax_col_dendrogram.set_visible(True)
ax = cg.ax_heatmap
ax.set_xlabel('Patients', fontsize=16)
ax.set_ylabel('Radiomics Features', fontsize=16)
# set_ticks_position / set_label_position return None, so there is nothing
# worth keeping from these calls.
cg.ax_cbar.yaxis.set_ticks_position('left')
cg.ax_cbar.yaxis.set_label_position('left')

# ':' is not a legal file-name character on Windows; also give an explicit extension.
cg.savefig(f'hierarchical cluster map - method {method}.png')

Solution

  • You will have to do the plot by hand; I don't think it's worth trying to hack around seaborn's ClusterGrid to get the result you need. You can generate the dendrograms using scipy and plot the heatmap(s) using imshow().

    I can't take the time to code an exact replica, but here is a quick mock-up. Hopefully there's no mistake in there, but it is just a demonstration that it is feasible.

    # Mock-up of an annotated clustermap built by hand: a patient dendrogram on
    # top, a re-ordered z-score heatmap below it, and one categorical strip per
    # patient annotation underneath, all sharing the same x axis.
    import matplotlib.gridspec
    import matplotlib.pyplot as plt
    import numpy as np
    import scipy.cluster.hierarchy  # `import scipy` alone does not guarantee the submodule is loaded

    # Random dummy data
    np.random.seed(1234)
    Npatients = 10
    Nfeatures = 20
    # Uniform noise z-scored per feature (column), so that the vmin=-2 / vmax=2
    # colour limits used for the heatmap actually span the data.
    raw_features = np.random.random((Npatients, Nfeatures))
    np_zfeatures = (raw_features - raw_features.mean(axis=0)) / raw_features.std(axis=0)  # [patients, features]
    patient_T_stage = np.random.randint(low=1, high=5, size=(Npatients,))
    patient_N_stage = np.random.randint(low=0, high=4, size=(Npatients,))
    patient_M_stage = np.random.randint(low=0, high=2, size=(Npatients,))
    patient_O_stage = np.random.randint(low=0, high=5, size=(Npatients,))
    patient_subtype = np.random.randint(low=0, high=5, size=(Npatients,))
    feature_class = np.random.randint(low=0, high=5, size=(Nfeatures,))  # 5 feature categories (first order, shape, textural, wavelet, LoG); not drawn in this mock-up

    # (values, label) for each annotation strip drawn under the heatmap.
    patient_annotations = [
        (patient_T_stage, 'T-stage'),
        (patient_N_stage, 'N-stage'),
        (patient_M_stage, 'M-stage'),
        (patient_O_stage, 'Overall stage'),
        (patient_subtype, 'Subtype'),
    ]
    N_rows_patients = len(patient_annotations)

    # HAC clustering (compute linkage matrices).
    # linkage() clusters the ROWS of its input: the rows of np_zfeatures are
    # patients and the rows of its transpose are features.
    method = 'ward'
    patient_links = scipy.cluster.hierarchy.linkage(np_zfeatures, method=method, metric='euclidean')
    feature_links = scipy.cluster.hierarchy.linkage(np_zfeatures.T, method=method, metric='euclidean')

    fig = plt.figure()

    # Outer grid: [dendrogram + heatmap] on top, annotation strips below.
    gs0 = matplotlib.gridspec.GridSpec(2, 1, figure=fig,
                                       height_ratios=[8, 2], hspace=0.05)
    # Inner grid of the top cell: dendrogram above, heatmap below, no gap.
    gs1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs0[0],
                                                      height_ratios=[2, 8],
                                                      hspace=0)

    ax_heatmap = fig.add_subplot(gs1[1])
    ax_col_dendrogram = fig.add_subplot(gs1[0], sharex=ax_heatmap)

    # Patients run along the x axis, so their dendrogram is drawn above the heatmap;
    # for features only the leaf order is needed.
    patient_dendrogram = scipy.cluster.hierarchy.dendrogram(patient_links, ax=ax_col_dendrogram)
    feature_dendrogram = scipy.cluster.hierarchy.dendrogram(feature_links, no_plot=True)
    ax_col_dendrogram.set_axis_off()

    patient_order = patient_dendrogram['leaves']
    feature_order = feature_dendrogram['leaves']

    # Stretch every image over the dendrogram's x range so that each column
    # lines up with its dendrogram leaf.
    xmin, xmax = ax_col_dendrogram.get_xlim()
    ordered_zfeatures = np_zfeatures[np.ix_(patient_order, feature_order)].T  # [features, patients]
    ax_heatmap.imshow(ordered_zfeatures, aspect='auto', extent=[xmin, xmax, 0, 1],
                      cmap='Spectral_r', vmin=-2, vmax=2)
    ax_heatmap.yaxis.tick_right()
    plt.setp(ax_heatmap.get_xticklabels(), visible=False)

    gs2 = matplotlib.gridspec.GridSpecFromSubplotSpec(N_rows_patients, 1, subplot_spec=gs0[1])

    for i, (values, label) in enumerate(patient_annotations):
        ax = fig.add_subplot(gs2[i], sharex=ax_heatmap)
        # One-row image: one coloured cell per patient, in dendrogram order.
        ax.imshow(values[patient_order][np.newaxis, :], aspect='auto',
                  extent=[xmin, xmax, 0, 1], cmap='Blues')
        ax.set_yticks([])
        ax.set_ylabel(label, rotation=0, ha='right', va='center')
        # Axes.is_last_row() was removed in Matplotlib 3.6; the loop index tells
        # us whether this is the bottom strip.
        if i < N_rows_patients - 1:
            plt.setp(ax.get_xticklabels(), visible=False)
    

    [Figure: the resulting annotated clustermap mock-up (image not included)]