python-3.x  plotly  plotly-python

Plotly python facetted heatmaps


I'm using the example from this SO Q&A to use seaborn for facetted heatmaps in python. The result looks like this:


I'd like to do the same thing with plotly express and have tried with this starter code:

import plotly.express as px
df = px.data.medals_wide(indexed=True)
fig = px.imshow(df)
fig.show()

My data is also in a pd.DataFrame, and it's important that I show the group each heatmap belongs to as well as the x- and y-axes of the maps.

How do you extend the px.imshow example to create a facetted heatmap by group like the seaborn example above?


Solution

  • The sample data is taken from the answers to the referenced question. Plotly Express can draw subplots from column data, but not when, as in the sample data, one categorical variable selects the subset and other categorical variables define the axes. It can be drawn as subplots with graph objects instead: filter the data frame by the category variable, then create a heat map from each filtered result by specifying its x- and y-axis columns.

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Generate a set of sample data: one value per (label0, label1, label2) combination
    np.random.seed(0)
    indices = pd.MultiIndex.from_product((range(5), range(5), range(5)), names=('label0', 'label1', 'label2'))
    data = pd.DataFrame(np.random.uniform(0, 100, size=len(indices)), index=indices, columns=('value',)).reset_index()

    # One subplot column per label0 group, titled with the group value
    groups = data['label0'].unique()
    titles = ['label0=' + str(g) for g in groups]
    fig = make_subplots(rows=1, cols=len(groups),
                        shared_yaxes=True,
                        subplot_titles=titles)

    # Filter the data by group and draw each subset as its own heatmap
    for col, g in enumerate(groups, start=1):
        df = data[data['label0'] == g]
        fig.add_trace(go.Heatmap(z=df['value'], x=df['label1'], y=df['label2']), row=1, col=col)
        fig.update_xaxes(title_text='label1', row=1, col=col)

    # Settings shared by all subplots
    fig.update_traces(showscale=False)
    fig.update_xaxes(dtick=1)  # dtick takes a single tick step, not a list
    fig.update_yaxes(title_text='label2', row=1, col=1)

    fig.show()
    
