I am new to plotly and I am trying to create subplots and display a color legend. My DataFrame looks like this:
A B C D State
0 1 3 5 2 INITIAL
1 2 10 5 1 DONE
2 4 1 7 6 IN_PROGRESS
3 4 3 2 8 PAUSED
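To reproduce the example, the DataFrame above can be built by hand. This is just the table re-typed as a dict of columns, so the snippets below have a df to work on.

```python
import pandas as pd

# The example data from the question, one list per column
df = pd.DataFrame({
    'A': [1, 2, 4, 4],
    'B': [3, 10, 1, 3],
    'C': [5, 5, 7, 2],
    'D': [2, 1, 6, 8],
    'State': ['INITIAL', 'DONE', 'IN_PROGRESS', 'PAUSED'],
})
print(df)
```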
This is what I need as the end result: [image: desired figure, not available]
But I cannot make it work. I tried using plotly.express and then building the final figure by adding the traces it generates. This is the code:
import pandas as pd
from plotly.subplots import make_subplots

pd.options.plotting.backend = "plotly"  # assumed: makes df.plot.* return plotly express figures

# One express figure per subplot
figures = [
    df.plot.scatter(x="A", y="B", color="State"),
    df.plot.scatter(x="A", y="C", color="State")
]
fig = make_subplots(rows=len(figures), cols=1)
for i, figure in enumerate(figures):
    # copy every trace of the express figure into row i+1
    for trace in figure["data"]:
        fig.add_trace(trace, row=i+1, col=1)
fig.update_xaxes(title_text="A", row=2, col=1)
fig.update_xaxes(title_text="A", row=1, col=1)
fig.update_yaxes(title_text="C", row=2, col=1)
fig.update_yaxes(title_text="B", row=1, col=1)
fig.show()
This is close to what I need, except that the states are repeated in the legend. How can I add a "State" header to the legend and avoid the repeated states?
To remove the duplicate legend entries, keep a set of the trace names already seen and hide the legend entry of every later trace with the same name. Also, since this kind of customisation is awkward with express, it is easier and programmatically clearer to build the figure with graph objects directly. Another option is to convert the data frame from wide to long format and loop over it with the same extraction conditions.
import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
colors = ['red', 'black', 'blue', 'green']
for i, g in enumerate(['B', 'C', 'D']):
    for s, c in zip(df['State'].unique(), colors):
        # all rows with state s (also works if a state appears in several rows)
        sub = df[df['State'] == s]
        fig.add_trace(go.Scatter(x=sub['A'],
                                 y=sub[g],
                                 mode='markers',
                                 marker_color=c,
                                 legendgroup=s,  # clicking a state toggles it in every subplot
                                 name=s),
                      row=i+1, col=1)

# Remove duplicate legends: show only the first trace with each name
names = set()
fig.for_each_trace(
    lambda trace:
        trace.update(showlegend=False)
        if (trace.name in names) else names.add(trace.name))

fig.update_layout(legend_title_text="State")  # "State" header above the legend
fig.update_xaxes(title_text="A", row=3, col=1, tickvals=[1, 2, 3, 4])
fig.update_yaxes(title_text="D", row=3, col=1)
fig.update_yaxes(title_text="C", row=2, col=1)
fig.update_yaxes(title_text="B", row=1, col=1)
fig.show()