
How can I create a tile plot using plotly, where each tile has a text label and a variable color intensity?


Here's my data

import pandas as pd

df = pd.DataFrame({
    'row': ['A','A','A','B','B','B'],
    'col': [1,2,3,1,2,3],
    'val': [0.1,0.9,0.2,0.5,0.2,0.2],
    'animal': ['duck', 'squirrel', 'horse', 'cow', 'pig', 'cat']
})


df
  row  col  val    animal
0   A    1  0.1      duck
1   A    2  0.9  squirrel
2   A    3  0.2     horse
3   B    1  0.5       cow
4   B    2  0.2       pig
5   B    3  0.2       cat

I would like to make a plot like this

[Image: the desired tile plot]

But the closest I can get (using imshow) is this

[Image: the imshow result]

Note: as far as I'm aware, imshow doesn't draw a border between tiles, which is also important to me.

My attempt

import pandas as pd
import plotly.express as px

df = pd.DataFrame({
    'row': ['A','A','A','B','B','B'],
    'col': [1,2,3,1,2,3],
    'val': [0.1,0.9,0.2,0.5,0.2,0.2],
    'animal': ['duck', 'squirrel', 'horse', 'cow', 'pig', 'cat']
})

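# Reshape to one row per 'row' value and one column per 'col' value
df_wide = pd.pivot_table(
    data=df,
    index='row',
    columns='col',
    values='val',
    aggfunc='first'
)

# Draw the grid as an image with a fixed [0, 1] gray color range
fig = px.imshow(
    img=df_wide,
    zmin=0,
    zmax=1,
    origin="upper",
    color_continuous_scale='gray'
)
fig.update_layout(coloraxis_showscale=False)  # hide the color bar
fig.show()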

Solution

You can try this:

    import pandas as pd
    import plotly.graph_objects as go
    
    df = pd.DataFrame({
        'row': ['A','A','A','B','B','B'],
        'col': [1,2,3,1,2,3],
        'val': [0.1,0.9,0.2,0.5,0.2,0.2],
        'animal': ['duck', 'squirrel', 'horse', 'cow', 'pig', 'cat']
    })
    
    df_wide = pd.pivot_table(
        data=df,
        index='row',
        columns='col',
        values='val',
        aggfunc='first'
    )
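The text labels have to end up in the same (row, col) cells as the values. One way to guarantee that is to pivot the animal column exactly like val; DataFrame.pivot works here because every (row, col) pair is unique. Reshaping df['animal'] directly only lines up when df is already sorted by row, then col. The shuffle below exists only to demonstrate this, and values, labels and naive are hypothetical names:

```python
import pandas as pd

df = pd.DataFrame({
    'row': ['A', 'A', 'A', 'B', 'B', 'B'],
    'col': [1, 2, 3, 1, 2, 3],
    'val': [0.1, 0.9, 0.2, 0.5, 0.2, 0.2],
    'animal': ['duck', 'squirrel', 'horse', 'cow', 'pig', 'cat'],
})

# Shuffle the rows to mimic input that is not sorted by (row, col)
shuffled = df.sample(frac=1, random_state=0)

# Pivot values and labels with the same index/columns so their cells line up
values = shuffled.pivot(index='row', columns='col', values='val')
labels = shuffled.pivot(index='row', columns='col', values='animal')

# Reshaping the raw column instead follows the shuffled row order,
# so it can put labels in the wrong cells
naive = shuffled['animal'].values.reshape(values.shape)

print(labels)
print(naive)
```

With the example data, labels has rows A, B and columns 1, 2, 3, holding duck/squirrel/horse in row A and cow/pig/cat in row B, cell for cell matching the values pivot.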
    

Then build the plot with go.Heatmap, whose xgap and ygap parameters control the gaps between tiles:

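    # Pivot the labels the same way as the values, so each label lands in
    # its own tile no matter how the rows of df are ordered
    df_text = df.pivot(index='row', columns='col', values='animal')

    fig = go.Figure(data=go.Heatmap(
        z=df_wide.values,
        x=list(df_wide.columns),
        y=list(df_wide.index),
        text=df_text.values,
        texttemplate="%{text}",   # show each tile's label
        textfont={"size": 15, "color": "red"},
        colorscale='gray',
        zmin=0,                   # fix the color range to [0, 1], as in the imshow attempt
        zmax=1,
        showscale=False,
        xgap=3,                   # horizontal gap between tiles, in pixels
        ygap=3,                   # vertical gap between tiles, in pixels
    ))

    # go.Heatmap puts the first row (A) at the bottom; uncomment the next
    # line to put it on top, as imshow's origin="upper" does
    # fig.update_yaxes(autorange="reversed")

    fig.show()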
    

[Image: the resulting heatmap]

Setting xgap=3 and ygap=3 leaves a 3-pixel gap between neighboring tiles, so the plot background shows through as a border.