
Plotly Strip plot - color by continuous scale


I would like to use plotly.express.strip to show the distribution of my data points. The catch is that I would also like to color each data point according to its value. As an example, taken from Plotly's documentation:

import plotly.express as px

df = px.data.tips()
fig = px.strip(df, x="total_bill", y="day")
fig.show()

This produces the standard strip plot (the "Plotly strip example" screenshot is not reproduced here). However, I would like the points to be colored by their "total_bill" amount using a continuous color scale.

From my research, I haven't found a way to do this with this specific plot type (i.e., I can't use px.scatter instead).

Is there any way to do this? Thanks


Solution

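  • You need jitter. I applied a pseudo-jitter function to the y-axis so that an ordinary scatter plot looks like a strip plot. The jitter function was inspired by the contents of this page, and the result may differ from Plotly's own strip-plot logic. At first I built one trace per day of the week in a loop, but that repeated the color bar for every trace. So instead I wrote out each day explicitly and attached the color bar only to Saturday's trace, since Saturday contains the largest value. Note that each trace still computes its own color range from its own data (Plotly's default `cauto` behavior), so the same color can correspond to slightly different amounts on different days.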

    import plotly.graph_objects as go
    import plotly.express as px
    import numpy as np
    
    df = px.data.tips()
    N = len(df)
    
    fig = go.Figure()
    
    dfs = df.query('day == "Sun"')
    fig.add_trace(go.Scatter(
        x=dfs['total_bill'], 
        y=0 + np.random.rand(len(dfs)) * 0.2,  # one jitter value per Sunday row
        mode='markers',
        marker=dict(
            size=9,
            color=dfs['total_bill'],
        ),
        name='Sun',
    ))
    
    dfst = df.query('day == "Sat"')
    fig.add_trace(go.Scatter(
        x=dfst['total_bill'], 
        y=1 + np.random.rand(len(dfst)) * 0.2,  # one jitter value per Saturday row
        mode='markers',
        marker=dict(
            size=9,
            color=dfst['total_bill'],
            colorbar=dict(
                title='total_bill',
            ),
            colorscale='Plasma'
        ),
        name='Sat',
    ))
    
    dfth = df.query('day == "Thur"')
    fig.add_trace(go.Scatter(
        x=dfth['total_bill'], 
        y=2 + np.random.rand(len(dfth)) * 0.2,  # one jitter value per Thursday row
        mode='markers',
        marker=dict(
            size=9,
            color=dfth['total_bill'],
        ),
        name='Thur',
    ))
    
    dff = df.query('day == "Fri"')
    fig.add_trace(go.Scatter(
        x=dff['total_bill'], 
        y=3 + np.random.rand(len(dff)) * 0.2,  # one jitter value per Friday row
        mode='markers',
        marker=dict(
            size=9,
            color=dff['total_bill'],
        ),
        name='Fri',
    ))
    
    # The color bar comes from the Sat trace's marker.colorbar; no trace uses layout.coloraxis
    fig.update_layout(showlegend=False)
    fig.update_yaxes(tickvals=[0,1,2,3], ticktext=['Sun','Sat','Thur','Fri'])
    fig.show() 
    
