Plotly clustered heatmap (with dendrogram) in Python


I am trying to create a clustered heatmap (with a dendrogram) using plotly in Python. The example on the plotly website does not scale well. I have found various solutions, but most of them are in R or JavaScript. I want a heatmap with a dendrogram on the left side of the heatmap only, showing the clusters along the y-axis (from the hierarchical clustering). A really good-looking example is this one: https://chart-studio.plotly.com/~jackp/6748. My goal is to create something like this, but with only the left-side dendrogram. If someone can implement something like this in Python, I will be really grateful!

Let the data be X = np.random.randint(0, 10, size=(120, 10)).


Solution

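  • The following suggestion draws on elements from both Dendrograms in Python and chart-studio.plotly.com/~jackp. This particular plot uses your data X = np.random.randint(0, 10, size=(120, 10)). One thing the linked approaches have in common is, in my opinion, that the datasets and data-munging procedures are a bit messy. So I decided to build the following figure on a pandas DataFrame, df = pd.DataFrame(X), to hopefully make everything a bit clearer. Note that the heatmap shows the pairwise Euclidean distances between the 120 rows (a 120 × 120 matrix), reordered by the dendrogram's leaf order, and that the upper dendrogram is only created to position the heatmap columns and is then hidden.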

    Plot: [image of the resulting figure not included]

    Complete code

    import plotly.graph_objects as go
    import plotly.figure_factory as ff

    import numpy as np
    import pandas as pd
    from scipy.spatial.distance import pdist, squareform
    
    X = np.random.randint(0, 10, size=(120, 10))
    df = pd.DataFrame(X)
    
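    # Initialize the figure with an upper dendrogram. It is hidden because only
    # the left-side dendrogram is wanted; its x-axis tick positions are reused
    # below to place the heatmap columns.
    fig = ff.create_dendrogram(df.values, orientation='bottom')
    fig.for_each_trace(lambda trace: trace.update(visible=False))

    # Move the (hidden) upper dendrogram to a second y-axis
    for i in range(len(fig['data'])):
        fig['data'][i]['yaxis'] = 'y2'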
    
    # Create the side dendrogram (clusters the rows) on a second x-axis
    dendro_side = ff.create_dendrogram(df.values, orientation='right')
    for i in range(len(dendro_side['data'])):
        dendro_side['data'][i]['xaxis'] = 'x2'

    # Add the side dendrogram traces to the figure
    for data in dendro_side['data']:
        fig.add_trace(data)
    
    # Create the heatmap: pairwise Euclidean distances between the rows of df,
    # reordered on both axes according to the side dendrogram's leaf order
    dendro_leaves = dendro_side['layout']['yaxis']['ticktext']
    dendro_leaves = list(map(int, dendro_leaves))
    data_dist = pdist(df.values)
    heat_data = squareform(data_dist)
    heat_data = heat_data[dendro_leaves, :]
    heat_data = heat_data[:, dendro_leaves]

    # Position the cells at the dendrograms' leaf coordinates so that the
    # heatmap rows line up with the branches of the side dendrogram
    heatmap = go.Heatmap(
        x=fig['layout']['xaxis']['tickvals'],
        y=dendro_side['layout']['yaxis']['tickvals'],
        z=heat_data,
        colorscale='Blues'
    )

    # Add the heatmap to the figure
    fig.add_trace(heatmap)
    
    # Edit layout
    fig.update_layout({'width': 800, 'height': 800,
                       'showlegend': False, 'hovermode': 'closest'})

    # Edit xaxis (heatmap)
    fig.update_layout(xaxis={'domain': [.15, 1],
                             'mirror': False,
                             'showgrid': False,
                             'showline': False,
                             'zeroline': False,
                             'ticks': ""})

    # Edit xaxis2 (side dendrogram)
    fig.update_layout(xaxis2={'domain': [0, .15],
                              'mirror': False,
                              'showgrid': False,
                              'showline': False,
                              'zeroline': False,
                              'showticklabels': False,
                              'ticks': ""})

    # Edit yaxis (shared by the heatmap and the side dendrogram)
    fig.update_layout(yaxis={'domain': [0, 1],
                             'mirror': False,
                             'showgrid': False,
                             'showline': False,
                             'zeroline': False,
                             'showticklabels': False,
                             'ticks': ""})

    # Edit yaxis2 (hidden upper dendrogram)
    fig.update_layout(yaxis2={'domain': [.825, .975],
                              'mirror': False,
                              'showgrid': False,
                              'showline': False,
                              'zeroline': False,
                              'showticklabels': False,
                              'ticks': ""})

    # Transparent backgrounds; hide the x tick labels by making them transparent
    fig.update_layout(paper_bgcolor="rgba(0,0,0,0)",
                      plot_bgcolor="rgba(0,0,0,0)",
                      xaxis_tickfont=dict(color='rgba(0,0,0,0)'))

    fig.show()