
How to set the color of a plotly line based on a number


I have the following data frame:

import pandas as pd
import numpy as np

df = pd.DataFrame(
    data=np.cumsum( np.sqrt(1 / 1000) * np.random.normal(size=(1000, 10)), axis=0),
    columns=np.array([*range(1, 11)]))

and I want to plot it using plotly express, with each curve's color based on its column value and evolving continuously. For example, curves 1-3 or so could be yellow, 4-7 orange, and 8-10 red.
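The yellow, orange, red progression described above amounts to a continuous colormap evaluated at each column label. Here is a minimal sketch that computes one hex color per column. It assumes matplotlib, which the question already uses, and the colormap name `'yellow_orange_red'` and the variable names are illustrative only:

```python
import matplotlib.colors as mplc

# Continuous colormap running yellow -> orange -> red (name is illustrative)
yor_cmap = mplc.LinearSegmentedColormap.from_list(
    'yellow_orange_red', ['yellow', 'orange', 'red'])
norm = mplc.Normalize(vmin=1, vmax=10)   # column labels 1..10 -> [0, 1]

# One hex color per column; hex strings are valid plotly colors, e.g. line={'color': ...}
colors = {c: mplc.to_hex(yor_cmap(norm(c))) for c in range(1, 11)}
print(colors[1], colors[10])   # #ffff00 #ff0000
```

Swapping `yor_cmap` for `'viridis'` gives the palette used in the matplotlib code further down.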

I tried

import plotly.express as px
fig = px.line(df, x=df.index, y=df.columns, color=df.columns)
fig.show()

but I received the error

All arguments should have the same length. The length of the argument 'color' is 10 where as the length of the previously-processed arguments ['index', '1', '2', '3', '4',..., '10'] is 1000

Essentially, I want to translate the following matplotlib code into plotly:

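import matplotlib.pyplot as plt
import matplotlib.colors as mplc

cmap = plt.get_cmap('viridis', df.shape[1])
norm = mplc.Normalize(vmin=1, vmax=10)
# Color each curve by its column label (1..10), which is the range norm expects
for col in df.columns:
    plt.plot(df[col].values, color=cmap(norm(col)))
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# Newer matplotlib versions need to know which Axes the colorbar belongs to
plt.colorbar(sm, ax=plt.gca(), label='signal number')
plt.show()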

which produces the following plot (image not included).
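Translating the matplotlib version mostly comes down to converting colors. `cmap(norm(x))` returns an RGBA tuple of floats in [0, 1]. Plotly color strings follow the CSS convention, where `rgb`/`rgba` channels run from 0 to 255. As far as I can tell, pasting the tuple straight into an `'rgba(...)'` string therefore gives almost black lines. Under NumPy 2 the tuple's repr also contains `np.float64(...)`, which is not a valid color at all. The sketch below shows two conversions that should work: `matplotlib.colors.to_hex`, and a hypothetical helper `mpl_to_plotly_rgba` (not part of either library) that rescales the channels:

```python
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

cmap = plt.get_cmap('viridis', 10)       # 10 levels, as in the question
norm = mplc.Normalize(vmin=1, vmax=10)   # maps signal numbers 1..10 onto [0, 1]


def mpl_to_plotly_rgba(rgba):
    """Hypothetical helper: turn a matplotlib RGBA tuple (channels in [0, 1])
    into a CSS-style 'rgba(r, g, b, a)' string with r, g, b in 0..255."""
    r, g, b, a = (float(v) for v in rgba)
    return f'rgba({round(r * 255)}, {round(g * 255)}, {round(b * 255)}, {a})'


first = cmap(norm(1))                    # RGBA floats in [0, 1]
hex_first = mplc.to_hex(first)           # hex string, alpha dropped
rgba_first = mpl_to_plotly_rgba(first)   # keeps alpha
print(hex_first)    # #440154
print(rgba_first)   # rgba(68, 1, 84, 1.0)
```

Either string can go into `line={'color': ...}` of a `go.Scatter` trace. `to_hex` is the simpler choice unless you need per-line transparency.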


Solution

  • Here is the code I ended up with. I switched the example data to 100 straight lines to show that the colors do not repeat, even with many traces.

      import plotly.graph_objects as go
      import plotly.express as px
      import pandas as pd
      import numpy as np
      import matplotlib.pyplot as plt
      import matplotlib.colors as mplc
    
      no_sigs = 100
      # x values; `time` was not defined in the original snippet, any 1-D array works
      time = np.linspace(0, 1, 1000)
    
      # no_sigs straight lines with increasing offsets
      line = lambda x, b: x + b
      lines = np.array(
              [line(time, b) for b in np.linspace(0, 1, no_sigs)]
              )
      line_names = np.array([*range(1, no_sigs + 1)])
      lines_df = pd.DataFrame(
              data=lines.T,
              index=time,
              columns=line_names)
    
      # Map signal numbers onto viridis
      cmap = plt.get_cmap('viridis', no_sigs)
      norm = mplc.Normalize(vmin=line_names[0], vmax=line_names[-1])
    
      # Dummy scatter that only produces the color bar; NaN coordinates mean
      # no markers are drawn (np.empty would plot arbitrary leftover values)
      c_bar_data = np.vstack([lines_df.columns.values,
                              np.full((2, no_sigs), np.nan)]).T
      colorbar_df = pd.DataFrame(
              data=c_bar_data,
              columns=['sig_num', 'empt1', 'empt2'])
      fig = px.scatter(colorbar_df,
                       x='empt1',
                       y='empt2',
                       color='sig_num',
                       color_continuous_scale=px.colors.sequential.Viridis)
    
      # Plot the lines; to_hex converts matplotlib's 0-1 RGBA tuple into a
      # hex string plotly understands
      for c in lines_df.columns:
          _ = fig.add_trace(
                  go.Scatter(
                      x=lines_df.index,
                      y=lines_df[c].values,
                      mode='lines',
                      line={'color': mplc.to_hex(cmap(norm(c)))},
                      showlegend=False)
                  )
      _ = fig.update_xaxes(
              title={'text': 'x-axis'})
      _ = fig.update_yaxes(
              title={'text': 'y-axis'})
      fig.show()
    

    Here is the output (image not included).