
How to jitter scatter points on a px.imshow heatmap in Python Plotly


I have a risk matrix made with Plotly's heatmap, and I have overlaid scatter points on it, one per project. If two projects fall in the same box, how can I jitter the points so both fit in the box without overlapping? For example, project1 is in the box with impact Medium and likelihood Low. If project2 is in the same box, how can both fit there without overlapping?


import plotly.express as px
import plotly.graph_objects as go

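# Risk matrix: columns are Likelihood, rows are Impact (High at the top)
fig = px.imshow([[3, 5, 5],
                 [1, 3, 5],
                 [1, 1, 3]],
                color_continuous_scale='Reds',
                labels=dict(x="Likelihood", y="Impact"),
                x=['Low', 'Medium', 'High'],
                y=['High', 'Medium', 'Low']
                )
# One marker per project, placed by category: Likelihood=Low, Impact=Medium
fig.add_trace(go.Scatter(x=["Low"], y=["Medium"],
                         name="project1",
                         marker=dict(color='black', size=16)))
fig.show()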

Expected Outcome

[Image not available: the expected chart shows several project markers spread out inside a single cell.]


Solution

  • Build the figure with graph objects and keep both axes numeric: each heatmap cell is centered on an integer (x, y) position and spans ±0.5 around it. Jitter the scatter points by placing them at fractional offsets inside the cell, then relabel the ticks with the category strings (tickvals and ticktext) so the axes still read as categorical. The offsets below are a guess read off your expected graph and added by hand; replace them with your own data.

    import plotly.graph_objects as go
    
    fig = go.Figure()
    # With go.Heatmap the axes are numeric: cell (column i, row j) is centered
    # on (i, j) and spans +/-0.5 around it.
    fig.add_trace(go.Heatmap(z=[[3, 5, 5],
                                [1, 3, 5],
                                [1, 1, 3]],
                             colorscale='Reds'))
    # Markers in the Likelihood=Low (x=0), Impact=Medium (y=1) cell,
    # offset by +/-0.25 from the cell center so they do not overlap.
    fig.add_trace(go.Scatter(x=[-0.25], y=[0.75], name="project1", marker=dict(color='black', size=16), showlegend=False))
    fig.add_trace(go.Scatter(x=[-0.25], y=[1.25], name="project2", marker=dict(color='black', size=16), showlegend=False))
    fig.add_trace(go.Scatter(x=[0.25], y=[0.75], name="project3", marker=dict(color='black', size=16), showlegend=False))
    fig.add_trace(go.Scatter(x=[0.25], y=[1.25], name="project4", marker=dict(color='black', size=16), showlegend=False))
    fig.add_trace(go.Scatter(x=[0], y=[1.0], name="project5", marker=dict(color='black', size=8), showlegend=False))
    
    fig.update_layout(height=400, width=400)
    # Relabel the numeric ticks with the category names; reverse y so High is on top
    fig.update_xaxes(tickvals=[0, 1, 2], ticktext=['Low', 'Medium', 'High'], title_text='Likelihood')
    fig.update_yaxes(tickvals=[0, 1, 2], ticktext=['High', 'Medium', 'Low'], title_text='Impact', autorange='reversed')
    fig.show()
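The offsets above are hard-coded for one cell. Below is a minimal sketch of computing them automatically, assuming the projects are given as a dict mapping each project name to a (likelihood, impact) label pair. Projects that share a cell are laid out on a small grid of offsets within ±spread of the cell center (spread=0.25 matches the manual offsets above), and a project alone in its cell sits at the center. The names grid_offsets, jittered_positions and build_figure are hypothetical helpers for illustration, not part of Plotly.

```python
import math
from collections import defaultdict

# Category order along each axis; the list index is the numeric cell position.
# y follows the question's order (High at index 0, drawn on top once reversed).
X_LABELS = ["Low", "Medium", "High"]   # Likelihood
Y_LABELS = ["High", "Medium", "Low"]   # Impact


def _spaced(k, spread):
    """k evenly spaced values in [-spread, spread]; a single value sits at 0."""
    if k == 1:
        return [0.0]
    step = 2 * spread / (k - 1)
    return [-spread + i * step for i in range(k)]


def grid_offsets(n, spread=0.25):
    """Return n distinct (dx, dy) offsets on a near-square grid around 0.

    With spread < 0.5 every point stays inside its heatmap cell.
    """
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    xs, ys = _spaced(cols, spread), _spaced(rows, spread)
    return [(xs[i % cols], ys[i // cols]) for i in range(n)]


def jittered_positions(projects, spread=0.25):
    """Map {name: (likelihood, impact)} to {name: (x, y)} numeric coordinates."""
    cells = defaultdict(list)
    for name, (likelihood, impact) in projects.items():
        cells[(X_LABELS.index(likelihood), Y_LABELS.index(impact))].append(name)
    positions = {}
    for (cx, cy), names in cells.items():
        for name, (dx, dy) in zip(names, grid_offsets(len(names), spread)):
            positions[name] = (cx + dx, cy + dy)
    return positions


def build_figure(projects, spread=0.25):
    # Imported here so the helpers above can be used without plotly installed.
    import plotly.graph_objects as go

    fig = go.Figure(go.Heatmap(z=[[3, 5, 5], [1, 3, 5], [1, 1, 3]], colorscale="Reds"))
    positions = jittered_positions(projects, spread)
    fig.add_trace(go.Scatter(
        x=[x for x, _ in positions.values()],
        y=[y for _, y in positions.values()],
        mode="markers",
        text=list(positions),          # project names shown on hover
        hoverinfo="text",
        marker=dict(color="black", size=16),
        showlegend=False,
    ))
    fig.update_xaxes(tickvals=[0, 1, 2], ticktext=X_LABELS, title_text="Likelihood")
    fig.update_yaxes(tickvals=[0, 1, 2], ticktext=Y_LABELS, title_text="Impact",
                     autorange="reversed")
    fig.update_layout(height=400, width=400)
    return fig
```

For example, build_figure({"project1": ("Low", "Medium"), "project2": ("Low", "Medium")}).show() should draw the two projects side by side in the Likelihood=Low, Impact=Medium cell. Laying points out on a fixed grid, rather than adding random noise, keeps them apart as long as the markers are small relative to the cell.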
    

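If you prefer classic random jitter to a fixed layout, here is a sketch using only the standard library: each point is offset from its numeric cell center by a uniform random amount in [-spread, spread] on both axes, with a fixed seed so the chart looks the same on every run. The function name random_jitter is hypothetical. Unlike grid offsets, random jitter does not guarantee separation; two points in the same cell can still land close together, so keep spread below 0.5 to stay inside the cell and check the result visually.

```python
import random


def random_jitter(cells, spread=0.3, seed=0):
    """Offset each numeric (x, y) cell center by uniform noise in [-spread, spread].

    cells: list of cell centers, e.g. (0, 1) for Likelihood=Low, Impact=Medium
    when the y-axis order is High, Medium, Low.
    A fixed seed makes the layout identical on every run.
    """
    rng = random.Random(seed)
    return [
        (cx + rng.uniform(-spread, spread), cy + rng.uniform(-spread, spread))
        for cx, cy in cells
    ]


# Two projects in the same cell get different (random) positions.
points = random_jitter([(0, 1), (0, 1)])
```

The resulting coordinates can be passed to go.Scatter as x=[x for x, _ in points] and y=[y for _, y in points], with the same tickvals/ticktext relabeling as in the answer.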