I would like to use plotly.express.strip to show the distribution of my data points. The catch is that I would also like to color each data point according to its value.
As an example, taken from Plotly's documentation:
import plotly.express as px
df = px.data.tips()
fig = px.strip(df, x="total_bill", y="day")
fig.show()
This shows the strip plot, but I would like the points to be colored by the "total_bill" amount using a continuous color scale.
From my research I haven't found a way to do this using this specific plot (i.e. I can't use a px.scatter).
Is there any way to do this? Thanks
You need jitter. I added pseudo-random jitter to the y-axis positions of a scatter plot to imitate a strip plot. The jitter function was inspired by the contents of this page, and it may differ from the logic Plotly uses for its own strip plots. At first I built the graph by looping over the days of the week, but that duplicated the color bar, so I wrote out each day explicitly and attached the color bar only to the Saturday trace, since Saturday contains the largest value.
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
df = px.data.tips()
N = len(df)
fig = go.Figure()
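dfs = df.query('day == "Sun"')
fig.add_trace(go.Scatter(
    x=dfs['total_bill'],
    # one jitter value per row of this day's subset, not per row of the full frame
    y=0 + np.random.rand(len(dfs)) * 0.2,
    mode='markers',
    marker=dict(
        size=9,
        color=dfs['total_bill'],
        colorscale='Plasma',
        # pin the color range to the whole data set so every day shares one scale
        cmin=df['total_bill'].min(),
        cmax=df['total_bill'].max(),
    ),
    name='Sun',
))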
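dfst = df.query('day == "Sat"')
fig.add_trace(go.Scatter(
    x=dfst['total_bill'],
    # one jitter value per row of this day's subset, not per row of the full frame
    y=1 + np.random.rand(len(dfst)) * 0.2,
    mode='markers',
    marker=dict(
        size=9,
        color=dfst['total_bill'],
        # only this trace draws the color bar, so it appears once
        colorbar=dict(
            title='total_bill',
        ),
        colorscale='Plasma',
        # pin the color range to the whole data set so every day shares one scale
        cmin=df['total_bill'].min(),
        cmax=df['total_bill'].max(),
    ),
    name='Sat',
))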
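dfth = df.query('day == "Thur"')
fig.add_trace(go.Scatter(
    x=dfth['total_bill'],
    # one jitter value per row of this day's subset, not per row of the full frame
    y=2 + np.random.rand(len(dfth)) * 0.2,
    mode='markers',
    marker=dict(
        size=9,
        color=dfth['total_bill'],
        colorscale='Plasma',
        # pin the color range to the whole data set so every day shares one scale
        cmin=df['total_bill'].min(),
        cmax=df['total_bill'].max(),
    ),
    name='Thur',
))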
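dff = df.query('day == "Fri"')
fig.add_trace(go.Scatter(
    x=dff['total_bill'],
    # one jitter value per row of this day's subset, not per row of the full frame
    y=3 + np.random.rand(len(dff)) * 0.2,
    mode='markers',
    marker=dict(
        size=9,
        color=dff['total_bill'],
        colorscale='Plasma',
        # pin the color range to the whole data set so every day shares one scale
        cmin=df['total_bill'].min(),
        cmax=df['total_bill'].max(),
    ),
    name='Fri',
))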
# no trace is attached to a layout coloraxis, so coloraxis_showscale has no effect here
fig.update_layout(showlegend=False)
fig.update_yaxes(tickvals=[0,1,2,3], ticktext=['Sun','Sat','Thur','Fri'])
fig.show()