
How to plot a single line in Plotly with multiple colors according to a categorical variable


How can I get a single connected line in plotly with different colors?

The code below is my attempt at a solution. However, the line has an ugly break: the black part jumps straight across the blue section (from point 9 to point 60), and the segments are not joined where the color changes. How can I draw a single line whose color changes according to a categorical variable, without it breaking?

import numpy as np
import random
import pandas as pd
import plotly.express as px

white_noise = np.array([random.gauss(mu=0.0, sigma=1.0) for x in range(100)])
rw = white_noise.cumsum()
rw_df = pd.DataFrame({
    'random_walk': rw, 
    'color': 10*['black'] + 50 * ['blue'] + 40*['black'], 
    'x': range(100)
})

fig = px.line(rw_df, x='x', y='random_walk', color='color')
fig.show()

Solution

  • See if this is what you are looking for.

    First, split the dataframe into a list of smaller dataframes, each holding one consecutive run of same-colored rows. In your example that gives three dataframes: the first black segment, the blue segment, and the second black segment.

    Second, plot the first dataframe as a line with px.line. Then, for each subsequent dataframe, call fig.add_scatter to add a new line in that segment's color.

    import numpy as np
    import random
    import pandas as pd
    import plotly.express as px
    
    white_noise = np.array([random.gauss(mu=0.0, sigma=1.0) for x in range(100)])
    rw = white_noise.cumsum()
    rw_df = pd.DataFrame({
        'random_walk': rw, 
        'color': 10*['black'] + 50 * ['blue'] + 40*['black'], 
        'x': range(100)
    })
    
    ## Break your dataframe into a list of smaller dataframes, each with a single color.
    ## A new group starts wherever the color differs from the previous row.
    rw_df['group'] = rw_df['color'].ne(rw_df['color'].shift()).cumsum()
    dfs = [data for _, data in rw_df.groupby('group')]
    
    ## Repeat the first row of the next group at the end of each group, so that
    ## adjacent segments share an endpoint and the line is not broken.
    for i in range(len(dfs) - 1):
        dfs[i] = pd.concat([dfs[i], dfs[i + 1].iloc[:1]])
    
    ## Plot the first line - dfs[0]
    fig = px.line(dfs[0], x='x', y='random_walk')
    fig.update_traces(line_color=dfs[0]['color'].iloc[0])
    
    ## For the other dataframes in the list, add a new line trace
    for df in dfs[1:]:
        fig.add_scatter(x=df['x'], y=df['random_walk'],
                        mode='lines',
                        line_color=df['color'].iloc[0],
                        name=df['color'].iloc[0])
    
    fig.show()
    
