
Plotly: change line color on line crossings


I have a DataFrame with a date column and two data series. I want to plot both series and change the line color where they cross.

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020'],
        'S1': [20, 21, 19, 18, 20],
        'S2': [10, 11, 30, 18, 10]}
df = pd.DataFrame(data)

# One color per row: black where S1 >= S2, red otherwise
df['color'] = df.apply(lambda row: 'black' if row['S1'] >= row['S2'] else 'red', axis=1)

fig = make_subplots(rows=1, cols=1, 
                vertical_spacing=0.05,
                shared_xaxes=True)

fig.add_trace(go.Scatter(
     x=df['Date'],
     y=df['S1'],
     line=dict(color='blue')),
     row=1, col=1)

# This is the part that fails: line.color takes a single color, not one per point
fig.add_trace(go.Scatter(
     x=df['Date'],
     y=df['S2'],
     line=dict(color=list(df['color']))),
     row=1, col=1)

fig.show()

This is how it should look: [image: expected chart]. However, the code does not work. Any hints?
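The color column itself is computed correctly; the problem is where it is used. As far as I know, `line.color` on a `go.Scatter` trace accepts one color for the whole line, while `marker.color` does accept one color per point. Separately, the same column can be built without a row-wise `apply`. A minimal pandas/NumPy sketch, assuming the five-row data from the question (the use of `numpy.where` here is my own substitution, not part of the original code):

```python
import numpy as np
import pandas as pd

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020'],
        'S1': [20, 21, 19, 18, 20],
        'S2': [10, 11, 30, 18, 10]}
df = pd.DataFrame(data)

# Vectorized equivalent of the row-wise apply: 'black' where S1 >= S2, else 'red'
df['color'] = np.where(df['S1'] >= df['S2'], 'black', 'red')

print(df['color'].tolist())  # ['black', 'black', 'red', 'black', 'black']
```

Such a list could be passed as `marker=dict(color=...)` to color the points, but the line between them would still have a single color.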

Edit, to widen the solution space: it could also look like this: [image: alternative expected chart]

@Lemon: if I run your proposed solution on this data set:

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020',
                 '6.1.2020', '7.1.2020', '8.1.2020', '9.1.2020', '10.1.2020'],
        'S1': [20, 21, 19, 18, 20, 19, 18, 20, 20, 20],
        'S2': [10, 11, 30, 18, 10, 10, 17, 30, 20, 10]}

I get the following result: [image: chart in which the red line also runs between the two crossings]. I have not figured out how to remove the red line between the crossings. Any hints?
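A likely reason for the extra red segment: the red trace in the suggested answer is built from a subset of rows, and with this data that subset is not contiguous, so Plotly draws a straight line from the last point of the first block to the first point of the second. A quick pandas check, reusing the mask from Lemon's answer on the ten-row data:

```python
import pandas as pd

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020',
                 '6.1.2020', '7.1.2020', '8.1.2020', '9.1.2020', '10.1.2020'],
        'S1': [20, 21, 19, 18, 20, 19, 18, 20, 20, 20],
        'S2': [10, 11, 30, 18, 10, 10, 17, 30, 20, 10]}
df = pd.DataFrame(data)

# Rows where S2 is above S1, plus one neighbouring row on each side
mask = ((df.S1 < df.S2)
        | (df.S1.shift(1) < df.S2.shift(1))
        | (df.S1.shift(-1) < df.S2.shift(-1)))

selected = df.loc[mask, 'Date']
print(selected.tolist())
# ['2.1.2020', '3.1.2020', '4.1.2020', '7.1.2020', '8.1.2020', '9.1.2020']

# A jump of more than 1 in the row index marks the start of a new red block;
# in a single trace, the end of one block is joined to the start of the next
gaps = selected.index.to_series().diff().gt(1)
print(int(gaps.sum()))  # 1
```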

Working Solution

Based on Lemon's suggestion, here is the solution that works for me:

import pandas as pd
import plotly.graph_objects as go

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020',
                 '6.1.2020', '7.1.2020', '8.1.2020', '9.1.2020', '10.1.2020'],
        'S1': [20, 21, 19, 18, 20, 19, 18, 20, 20, 20],
        'S2': [10, 11, 30, 18, 10, 10, 17, 30, 20, 10]}
df = pd.DataFrame(data)

# Copy S2 only for rows where S2 > S1 and their direct neighbours;
# all other rows become NaN, which breaks the red line into separate pieces
df['crossing'] = df.loc[(df.S1 < df.S2)
                        | (df.S1.shift(1) < df.S2.shift(1))
                        | (df.S1.shift(-1) < df.S2.shift(-1)), 'S2']

fig = go.Figure()
fig.add_trace(go.Scatter(
  x=df['Date'],
  y=df['S1'],
  line=dict(color='blue')))

fig.add_trace(go.Scatter(
  x=df['Date'],
  y=df['S2'],
  line=dict(color='black')))

fig.add_trace(go.Scatter(
  x=df['Date'],
  y=df['crossing'],
  line=dict(color='red')))

fig.show()

Resulting chart:

[image: resulting chart]
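Why this version has no bridging segment: assigning `df.loc[mask, 'S2']` to a new column aligns on the index, so every row outside the mask becomes NaN, and `go.Scatter` leaves a gap at missing values because its `connectgaps` attribute defaults to `False`. A small pandas-only check of the `crossing` column, assuming the ten-row data above (the `present`/`runs` names are just for illustration):

```python
import pandas as pd

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020',
                 '6.1.2020', '7.1.2020', '8.1.2020', '9.1.2020', '10.1.2020'],
        'S1': [20, 21, 19, 18, 20, 19, 18, 20, 20, 20],
        'S2': [10, 11, 30, 18, 10, 10, 17, 30, 20, 10]}
df = pd.DataFrame(data)

df['crossing'] = df.loc[(df.S1 < df.S2)
                        | (df.S1.shift(1) < df.S2.shift(1))
                        | (df.S1.shift(-1) < df.S2.shift(-1)), 'S2']

# Rows outside the mask are NaN; these are where the red line is interrupted
print(df.index[df['crossing'].isna()].tolist())  # [0, 4, 5, 9]

# Count contiguous non-NaN runs, i.e. the separate red pieces Plotly will draw
present = df['crossing'].notna()
runs = (present & ~present.shift(1, fill_value=False)).sum()
print(int(runs))  # 2
```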


Solution

  • My initial idea was the same as r-beginners's, but to get a longer red line you can select the rows where S2 > S1 together with the row directly before and the row directly after each of them (a window of -1 to +1 around every "S2 > S1" row), using:

    df.loc[(df.S1 < df.S2) | 
           (df.S1.shift(1) < df.S2.shift(1)) | 
           (df.S1.shift(-1) < df.S2.shift(-1)), <the field you want to select>]
    

    The final code should look like this:

    import pandas as pd
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    
    data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020'],
            'S1': [20, 21, 19, 18, 20],
            'S2': [10, 11, 30, 18, 10]}
    df = pd.DataFrame(data)
    
    # Rows where S2 is above S1, plus one row before and one row after each of them
    mask = ((df.S1 < df.S2)
            | (df.S1.shift(1) < df.S2.shift(1))
            | (df.S1.shift(-1) < df.S2.shift(-1)))
    
    fig = make_subplots(rows=1, cols=1,
                        vertical_spacing=0.05,
                        shared_xaxes=True)
    
    # S1 in blue
    fig.add_trace(go.Scatter(
         x=df['Date'],
         y=df['S1'],
         line=dict(color='blue')),
         row=1, col=1)
    
    # S2 in black over its full range
    fig.add_trace(go.Scatter(
         x=df['Date'],
         y=df['S2'],
         line=dict(color='black')),
         row=1, col=1)
    
    # Red overlay: only the masked rows of S2, drawn on top of the black line
    fig.add_trace(go.Scatter(
         x=df.loc[mask, 'Date'],
         y=df.loc[mask, 'S2'],
         line=dict(color='red')),
         row=1, col=1)
    
    fig.show()
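The three-part condition can be written more compactly by computing the `S1 < S2` comparison once and shifting the resulting boolean Series; `fill_value=False` keeps the first and last rows out of the mask instead of relying on NaN comparisons evaluating to `False`. A minimal sketch, assuming the five-row data from the question, that checks both forms agree (the `below`, `mask_original` and `mask_compact` names are illustrative):

```python
import pandas as pd

data = {'Date': ['1.1.2020', '2.1.2020', '3.1.2020', '4.1.2020', '5.1.2020'],
        'S1': [20, 21, 19, 18, 20],
        'S2': [10, 11, 30, 18, 10]}
df = pd.DataFrame(data)

# Form used in the answer: each shifted comparison repeats S1 < S2
mask_original = ((df.S1 < df.S2)
                 | (df.S1.shift(1) < df.S2.shift(1))
                 | (df.S1.shift(-1) < df.S2.shift(-1)))

# Compact form: compare once, then extend the mask one row in each direction
below = df.S1 < df.S2
mask_compact = (below
                | below.shift(1, fill_value=False)
                | below.shift(-1, fill_value=False))

print(mask_compact.tolist())               # [False, True, True, True, False]
print(mask_compact.equals(mask_original))  # True
```

Either mask can be used in `df.loc[mask, 'Date']` and `df.loc[mask, 'S2']` as in the code above.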