Search code examples
Tags: python, plotly, streamlit

How to conditionally fill between two line charts with different colours using Plotly in Streamlit?


I'm trying to fill the space between two line charts (col1 and col2) with colour, as in this example plot: Example plot

Desired output:

  • When col1 is above col2, fill with green.
  • When col1 is below col2, fill with red.

I have tried this:

def DisplayPlot(df):
    """Plot col1, col2 and col3 against month and render the figure in Streamlit.

    Args:
        df: DataFrame with columns 'month', 'col1', 'col2' and 'col3'.
    """
    month = df['month'].tolist()
    col1 = df['col1'].tolist()
    col2 = df['col2'].tolist()
    col3 = df['col3'].tolist()

    # The figure has to exist before traces can be added to it.
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=month, y=col1, name='col1', line=dict(color='blue', width=4)))
    fig.add_trace(go.Scatter(x=month, y=col2, name='col2', line=dict(color='red', width=2)))
    fig.add_trace(go.Scatter(x=month, y=col3, name='col3', mode='lines', line=dict(color='#CDCDCD', width=2, dash='dot')))

    # Edit the layout
    fig.update_layout(
            title=dict(
                text='Title'
            ),
            xaxis=dict(
                title=dict(
                    text='Month'
                )
            ),
            yaxis=dict(
                title=dict(
                    text='Taux'
                )
            ),
    )
    st.plotly_chart(fig)

But I have no idea how to add the colour based on the condition above.

Here's the dataset:

MONTH COL1 COL2 COL3
Jan 0.1555 0.1256 0.1863
Feb 0.1097 0.119 0.1863
Apr 0.1459 0.175 0.1863
Mar 0.2804 0.2634 0.1863
May 0.267 0.1855 0.1863

Solution

  • The key idea is that the plot is made up of segments, and each segment has to be handled separately to get the desired result.

    By a segment I mean the region between the col1 and col2 lines from one data point (month) to the next. If the lines cross inside a segment, that segment is split at the intersection.

    The following cases need to be handled for each segment:

    • col1 is above col2 throughout the segment
    • col2 is above col1 throughout the segment
    • The lines cross within the segment, so the condition reverses partway through (the crossing point is found using linear interpolation)
    • The lines are parallel within the segment, which must be handled separately to avoid a division-by-zero error (see the code below)

    Here's a DisplayPlot function that handles these cases:

    # Month abbreviations in calendar order; position + 1 is the month number.
    _MONTHS = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
               "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
    _GREEN_FILL = "rgba(0, 255, 0, 0.3)"  # col1 above col2
    _RED_FILL = "rgba(255, 0, 0, 0.3)"  # col1 below col2


    def _months_to_numbers(months):
        """Return month numbers (Jan=1 ... Dec=12) for a column of abbreviations.

        Raises:
            ValueError: if any value is not one of ``_MONTHS``. Without this
                check ``pd.Categorical`` gives such values code -1, which would
                silently turn into month 0.
        """
        months = pd.Series(months)
        codes = pd.Categorical(months, categories=_MONTHS, ordered=True).codes
        unknown = months[codes < 0]
        if not unknown.empty:
            raise ValueError(
                f"Unrecognised month value(s): {unknown.unique().tolist()}"
            )
        return codes + 1


    def _fill_polygons(x, col1, col2):
        """Yield ``(xs, ys, fillcolor)`` polygons shading the gap between two lines.

        Each segment between consecutive points is filled green where
        col1 > col2 and red where col1 < col2. A segment in which the lines
        cross is split at the crossing point, found by linearly interpolating
        the difference ``col1 - col2``. Each piece is coloured by the sign of
        the difference at its outer end. A segment where the lines merely touch
        at its start is coloured by the sign at its far end.

        Args:
            x: x positions, in ascending order.
            col1, col2: y values of the two lines, same length as ``x``.
        """
        def colour(diff):
            return _GREEN_FILL if diff > 0 else _RED_FILL

        points = list(zip(x, col1, col2))
        for (x0, a0, b0), (x1, a1, b1) in zip(points, points[1:]):
            d0, d1 = a0 - b0, a1 - b1

            if (d0 > 0 and d1 < 0) or (d0 < 0 and d1 > 0):
                # Strict sign change: the crossing lies strictly inside the
                # segment, and d0 - d1 cannot be zero.
                t = d0 / (d0 - d1)
                xc = x0 + (x1 - x0) * t
                yc = a0 + (a1 - a0) * t
                yield [x0, xc, x0], [a0, yc, b0], colour(d0)
                yield [xc, x1, x1], [yc, a1, b1], colour(d1)
            else:
                # No crossing: one colour for the whole quadrilateral, taken
                # from the far end if the lines touch at the start.
                yield (
                    [x0, x1, x1, x0],
                    [a0, a1, b1, b0],
                    colour(d0 if d0 != 0 else d1),
                )


    def DisplayPlot(df):
        """Plot col1, col2 and col3 by month and shade the gap between col1 and col2.

        The gap is green where col1 is above col2 and red where it is below.
        Segments where the lines cross are split at the crossing point. The
        figure is rendered with ``st.plotly_chart``.

        Args:
            df: DataFrame with columns "month", "col1", "col2" and "col3".
                "month" may be numeric (1 = Jan) or month abbreviations
                ("Jan" ... "Dec"). The caller's DataFrame is not modified.

        Raises:
            ValueError: if a text "month" value is not a recognised
                abbreviation.
        """
        df = df.copy()  # never mutate the caller's DataFrame
        month_dtype = df["month"].dtype
        if month_dtype == "object" or pd.api.types.is_string_dtype(month_dtype):
            df["month"] = _months_to_numbers(df["month"])
        # Rows may not be in calendar order (e.g. Apr before Mar). Plot them
        # chronologically so the lines and fills don't double back on themselves.
        df = df.sort_values("month", kind="mergesort")

        # Plain lists give positional access regardless of the DataFrame's index.
        month = df["month"].tolist()
        col1 = df["col1"].tolist()
        col2 = df["col2"].tolist()
        col3 = df["col3"].tolist()

        fig = go.Figure()

        fig.add_trace(
            go.Scatter(x=month, y=col1, name="col1", line=dict(color="blue", width=4))
        )
        fig.add_trace(
            go.Scatter(x=month, y=col2, name="col2", line=dict(color="red", width=2))
        )
        fig.add_trace(
            go.Scatter(
                x=month,
                y=col3,
                name="col3",
                mode="lines",
                line=dict(color="#CDCDCD", width=2, dash="dot"),
            )
        )

        # One filled, legend-less trace per (sub-)segment between col1 and col2.
        for xs, ys, fillcolor in _fill_polygons(month, col1, col2):
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=ys,
                    fill="toself",
                    fillcolor=fillcolor,
                    mode="none",
                    showlegend=False,
                )
            )

        fig.update_layout(
            title=dict(text="Comparison of col1 and col2"),
            xaxis=dict(
                title=dict(text="Month"),
                # Label month numbers 1..12 with their abbreviations.
                tickvals=list(range(1, len(_MONTHS) + 1)),
                ticktext=_MONTHS,
            ),
            yaxis=dict(title=dict(text="Values")),
            showlegend=True,
        )

        st.plotly_chart(fig)
    

    You can find a full working example here with the corresponding Streamlit application deployed here.

    Update 1: Since your months are strings rather than numbers, add this at the start of the function to convert them to month numbers (the code above has been updated as well):

    def DisplayPlot(df):
        if df["month"].dtype == "object":
            df["month"] = (
                pd.Categorical(
                    df["month"],
                    categories=["Jan", "Feb", "Mar", "Apr", "May", "Jun"],
                    ordered=True,
                ).codes
                + 1
            )
    .
    .
    .
    

    You can add more months to the categories as needed; I have kept the list short for brevity.

    Update 2: To show month names instead of numbers on the x-axis, use a mapping dictionary in the final layout update (the original code has been updated as well):

    .
    .
    .
        month_mapping = {1: "Jan", 2: "Feb", 3: "Mar", 4: "Apr", 5: "May", 6: "Jun"}
    
        fig.update_layout(
            title=dict(text="Comparison of col1 and col2"),
            xaxis=dict(
                title=dict(text="Month"),
                tickvals=list(month_mapping.keys()),
                ticktext=list(month_mapping.values()),
            ),
            yaxis=dict(title=dict(text="Values")),
            showlegend=True,
        )
        
        st.plotly_chart(fig)
    .
    .
    .