Bokeh: legend with MultiIndex table


I discovered Bokeh recently and am trying to display a legend entry for each day of the week (stored in 'startdate_dayweek'). The legend should show the line color corresponding to each day.

import pandas as pd
from bokeh.plotting import figure, show
from bokeh.io import output_file
from bokeh.palettes import Set1_7

output_file("conso_daily.html")

treatcriteria_data_global = pd.read_csv(r"treatcriteria_evolution.csv", sep=';')

final_global_data = treatcriteria_data_global.groupby(['startdate_weekyear','startdate_dayweek'],as_index = False).sum().pivot('startdate_weekyear','startdate_dayweek').fillna(0)

numlines = len(final_global_data.columns)
palette = Set1_7[0:numlines]

ts_list_of_list = []
for i in range(0,len(final_global_data.columns)):
    ts_list_of_list.append(final_global_data.index)

vals_list_of_list = final_global_data.values.T.tolist()

p = figure(width=500, height=300)
p.left[0].formatter.use_scientific = False
p.multi_line(ts_list_of_list, vals_list_of_list,
             legend='startdate_dayweek',
             line_color = palette,
             line_width=4)
show(p)

But the legend does not show the expected result.

How can I get a legend entry for each day? Is the problem caused by the MultiIndex table I created? Thanks.


Solution

  • The multi_line() function accepts either the legend_field or the legend_group parameter. Both work well for your use case if you pass a ColumnDataSource as source. Keep in mind that Bokeh raises an error if you pass both parameters at the same time.

    Minimal Example

    from bokeh.plotting import figure, show, output_notebook
    from bokeh.models import ColumnDataSource
    output_notebook()
    
    source = ColumnDataSource(dict(
        xs=[[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5]], 
        ys=[[1,2,3,4,5],[1,1,1,1,5],[5,4,3,2,1]],
        legend =['red', 'green', 'blue'],
        line_color = ['red', 'green', 'blue']))
    
    p = figure(width=500, height=300)
    p.multi_line(xs='xs', 
                 ys='ys',
                 legend_field ='legend',
                 line_color = 'line_color',
                 source=source,
                 line_width=4)
    show(p)
    

    Output

    Multiline with Legend
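
Applied to the pivoted table from the question, the same idea might look like the sketch below. It rests on assumptions: the CSV is not available, so `raw` is made-up stand-in data with a single hypothetical value column `conso`, and the helper names `build_source_data` and `make_plot` are illustrative, not part of Bokeh. The day labels are read from the `startdate_dayweek` level of the MultiIndex columns, and Bokeh is imported inside `make_plot` so the data preparation can be checked on its own.

```python
import pandas as pd


def build_source_data(pivoted, palette):
    """Turn a pivoted frame (one column per day) into multi_line columns."""
    # One legend label per column, taken from the MultiIndex level.
    days = pivoted.columns.get_level_values("startdate_dayweek")
    n = len(days)
    return dict(
        xs=[pivoted.index.tolist()] * n,      # same x values for every line
        ys=pivoted.to_numpy().T.tolist(),     # one list of y values per day
        legend=[str(d) for d in days],        # labels as strings
        line_color=list(palette[:n]),         # one color per line
    )


def make_plot(data):
    """Build the figure; legend_field reads its labels from the 'legend' column."""
    from bokeh.models import ColumnDataSource
    from bokeh.plotting import figure

    p = figure(width=500, height=300)
    p.multi_line(xs="xs", ys="ys",
                 legend_field="legend",
                 line_color="line_color",
                 source=ColumnDataSource(data),
                 line_width=4)
    return p


# Made-up stand-in for treatcriteria_evolution.csv (assumption).
raw = pd.DataFrame({
    "startdate_weekyear": [1, 1, 2, 2, 3, 3],
    "startdate_dayweek": [1, 2, 1, 2, 1, 2],
    "conso": [10, 20, 30, 40, 50, 60],
})

pivoted = (raw.groupby(["startdate_weekyear", "startdate_dayweek"], as_index=False)
              .sum()
              .pivot(index="startdate_weekyear", columns="startdate_dayweek")
              .fillna(0))

data = build_source_data(pivoted, ["red", "green", "blue"])
```

To display the result, call `show(make_plot(data))` after importing `show` from `bokeh.plotting`. A Bokeh palette such as `Set1_7` from the question can be passed in place of the list of color names.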