
Plotly subplots - color legend for multicoloured subplots


I am new to Plotly and I am trying to create subplots and display a color legend. My DataFrame looks like this:

   A   B  C  D        State  
0  1   3  5  2      INITIAL    
1  2  10  5  1         DONE  
2  4   1  7  6  IN_PROGRESS  
3  4   3  2  8       PAUSED       
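
The table above is printed output. To make the snippets on this page reproducible, here is a minimal sketch that rebuilds the same DataFrame with pandas, copying only the values shown:

```python
import pandas as pd

# Rebuild the example DataFrame from the question (values copied from the printout)
df = pd.DataFrame({
    "A": [1, 2, 4, 4],
    "B": [3, 10, 1, 3],
    "C": [5, 5, 7, 2],
    "D": [2, 1, 6, 8],
    "State": ["INITIAL", "DONE", "IN_PROGRESS", "PAUSED"],
})
print(df)
```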

This is what I need as the end result: [screenshot not available]

But I cannot make it work. I tried using plotly.express (through the pandas plotting backend) and then building the final figure from the traces it generated. This is the code:

    import pandas as pd
    from plotly.subplots import make_subplots

    # df.plot.* returns Plotly (plotly.express) figures only with this backend
    pd.options.plotting.backend = "plotly"

    figures = [
        df.plot.scatter(x="A", y="B", color="State"),
        df.plot.scatter(x="A", y="C", color="State"),
    ]

    fig = make_subplots(rows=len(figures), cols=1)

    # Copy every trace of each express figure into its own subplot row
    for i, figure in enumerate(figures):
        for trace in figure.data:
            fig.add_trace(trace, row=i + 1, col=1)

    fig.update_xaxes(title_text="A", row=2, col=1)
    fig.update_xaxes(title_text="A", row=1, col=1)
    fig.update_yaxes(title_text="C", row=2, col=1)
    fig.update_yaxes(title_text="B", row=1, col=1)
    fig.show()

This is the result: [screenshot not available]

This is close to what I need, except that every state is listed twice in the legend (once per subplot). How can I add a "State" title to the legend and avoid the repeated states?


Solution

  • To remove the duplicate legend entries, collect the trace names in a set while walking over the figure's traces, and hide the legend entry (showlegend=False) of every trace whose name has already been seen. Also, because this kind of customisation is awkward with plotly.express, it is easier and programmatically clearer to build the figure with graph_objects directly. Another option is to convert the DataFrame from wide to long format and loop over the same filter conditions.

    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # df is the DataFrame from the question (columns A, B, C, D, State)
    fig = make_subplots(rows=3, cols=1, shared_xaxes=True)

    # One fixed colour per state so every subplot uses the same mapping
    colors = dict(zip(df['State'].unique(), ['red', 'black', 'blue', 'green']))

    for i, g in enumerate(['B', 'C', 'D']):
        for s, c in colors.items():
            dff = df[df['State'] == s]  # all rows with this state
            fig.add_trace(go.Scatter(x=dff['A'],
                                     y=dff[g],
                                     mode='markers',
                                     marker_color=c,
                                     legendgroup=s,  # toggles the state in all subplots
                                     name=s),
                          row=i+1, col=1)

    # Remove duplicate legends: keep only the first trace with each name
    names = set()
    fig.for_each_trace(
        lambda trace:
            trace.update(showlegend=False)
            if (trace.name in names) else names.add(trace.name))

    fig.update_layout(legend_title_text="State")  # the "State" header
    fig.update_xaxes(title_text="A", row=3, col=1, tickvals=[1, 2, 3, 4])
    fig.update_yaxes(title_text="D", row=3, col=1)
    fig.update_yaxes(title_text="C", row=2, col=1)
    fig.update_yaxes(title_text="B", row=1, col=1)

    fig.show()
    
