Filter data with Javascript callback in Python's Bokeh


Apologies in advance for any imprecise wording, as this is my first question here. Feel free to point out how I can improve it in the future.

I have read through Bokeh's user guide and various forums, but I believe this question is still insufficiently covered: it keeps coming up without an answer that can be applied generically.

My task is to construct a scatterplot in Python's Bokeh that can be filtered interactively based on a categorical variable. My limited understanding of JavaScript (and of how the data is structured) prevents me from figuring this out by myself.

I found that one solution is to append only the x and y values that fulfill the condition (e.g. Filtering Bokeh LabelSet with Javascript). However, I want to keep all the other variables as well, since I use them to define graphic parameters and hover information in the plot.

Hence my question: how can I append whole rows to the new output data in JavaScript if one of the columns fulfills a certain condition? I am also unsure whether I call the callback correctly, so that the plot actually reacts to my selection. Please feel free to point out any mistakes there as well.

See some example code here:

#Packages
import pandas as pd
import numpy as np
from bokeh.plotting import figure, output_file, show
import bokeh.events as bev
import bokeh.models as bmo
import bokeh.layouts as bla

#Data
data = pd.DataFrame(data = np.array([[1,1,'a',0.5],
                                     [2,2,'a',0.5],
                                     [3,3,'a',0.75],
                                     [4,4,'b',1],
                                     [5,5,'b',2]]),
                    columns = ['x', 'y', 'category', 'other information'])


#Setup
output_file('dashboard.html')

source = bmo.ColumnDataSource(data)

#Define dropdown options
dropdown_options = [('All', 'item_1'), None] + [(cat, str('item_' + str(i))) for i, cat in enumerate(sorted(data['category'].unique()), 2)]

#Generate dropdown widget
dropdown = bmo.Dropdown(label = 'Category', button_type = 'default', menu = dropdown_options)


#Callback
callback = bmo.CustomJS(args = dict(source = source),
                        code = """
                        
                        var data = source.data;
                        
                        var cat = cb_obj.value;
                        
                        if (cat = 'All'){
                                
                            data = source.data
                                
                        } else {
                            
                            var new_data = [];
                            
                            for (cat i = 0; i <= source.data['category'].length; i++){
                                    
                                    if (source.data['category'][i] == cat) {
                                            
                                            new_data.push(source.data[][i])
                                            
                                            }
                                    
                                    }
                            
                            data = new_data.data
                                                    
                        }
                            
                        source.data = data
                                                  
                        source.change.emit();
                        
                        """)


#Link actions
dropdown.js_on_event(bev.MenuItemClick, callback)

#Plot
p = figure(plot_width = 800, plot_height = 530, title = None)

p.scatter(x = 'x', y = 'y', source = source)


show(bla.column(dropdown, p))

Unsurprisingly, the filter does not work. As said, any help is highly appreciated, since I do not know how to index whole rows in JavaScript, or what else I am doing wrong.

Best regards, Oliver


Solution

  • I wrote a solution for your issue. I am no Bokeh expert, so I may not know everything, but I hope this helps you understand what is going on. Some explanation:

    • You had some syntax errors to start with: in your for loop you wrote cat i where you probably meant var i.

    • In your if condition you were assigning 'All' to cat. You need a comparison instead: either cat == 'All' or cat === 'All'.

    • Your cb_obj.value did not work and returned undefined. In a MenuItemClick callback, cb_obj is the event, and the clicked entry is available as cb_obj.item rather than cb_obj.value. You can inspect variables with console.log(variableName) and open the browser's developer console to watch the callback run. I changed your list comprehension to produce tuples of identical values, (category_name, category_name), instead of (category_name, item_category_number). Now cb_obj.item returns category_name, which you can compare against.

    • You should understand what format your data is in; console.log(source.data) shows it, for example. Here source.data is an object of arrays (a dictionary of lists, in Python terms). Because of that, you could not push rows the way you did in the for loop. There was also a syntax error: source.data[][i] cannot work, because an empty bracket does not select anything. I wrote two functions to handle this: generateNewDataObject creates an object of empty arrays, one per column, and addRowToAccumulator copies a row into it.

    • The last thing I needed was two data sources: one that is never modified, and a second one that is modified and displayed in the plot. If we modified the first one, then after the first filter all other categories would be dropped, and we could only get them back by refreshing the page. The 'immutable' data source lets us refer back to the full data, so nothing is lost when filtering.

    I hope that helps.
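
The assignment-versus-comparison point can be checked in isolation. The snippet below is a minimal sketch in plain JavaScript, runnable in Node or a browser console. The helper names brokenBranch and fixedBranch are made up for illustration and are not part of Bokeh.

```javascript
// Hypothetical helpers, for illustration only.

// Mirrors the original condition: `cat = 'All'` assigns, then tests the
// assigned string, which is truthy, so the "All" branch always runs.
function brokenBranch(cat) {
    if (cat = 'All') {
        return 'all rows';
    }
    return 'rows for ' + cat;
}

// Strict comparison takes the "All" branch only when cat really is 'All'.
function fixedBranch(cat) {
    if (cat === 'All') {
        return 'all rows';
    }
    return 'rows for ' + cat;
}

console.log(brokenBranch('a')); // all rows
console.log(fixedBranch('a'));  // rows for a
```

With the single equals sign, every menu choice would behave like 'All', which would hide the filter even if the rest of the callback were correct.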

    # Packages
    
    import bokeh.events as bev
    import bokeh.layouts as bla
    import bokeh.models as bmo
    import numpy as np
    import pandas as pd
    from bokeh.plotting import figure, output_file, show
    
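    # Data
    # A plain list of rows keeps x and y numeric; wrapping mixed-type rows in
    # np.array would turn every value into a string.
    data = pd.DataFrame(
        data=[
            [1, 1, 'a', 0.5],
            [2, 2, 'a', 0.5],
            [3, 3, 'a', 0.75],
            [4, 4, 'b', 1],
            [5, 5, 'b', 2]
        ],
        columns=['x', 'y', 'category', 'other information']
    )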
    
    # Setup
    output_file('dashboard.html')
    
    source = bmo.ColumnDataSource(data)
    
    # Define dropdown options as (label, item) pairs; None inserts a divider
    dropdown_options = [('All', 'All'), None] + [
        (cat, cat) for cat in sorted(data['category'].unique())
    ]
    # Generate dropdown widget
    dropdown = bmo.Dropdown(label='Category', button_type='default', menu=dropdown_options)
    
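    # Second source: the plot displays this one and the callback overwrites it,
    # while `source` stays untouched as the full reference data.
    filtered_data = bmo.ColumnDataSource(data)
    # Callback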
    callback = bmo.CustomJS(
        args=dict(unfiltered_data=source, filtered_data=filtered_data),
        code="""
    
    var data = unfiltered_data.data;
    var cat = cb_obj.item;
    
    function generateNewDataObject(oldDataObject){
        var newDataObject = {}
        for (var key of Object.keys(oldDataObject)){
            newDataObject[key] = [];
        }
        return newDataObject
    
    }
    
    function addRowToAccumulator(accumulator, dataObject, index) {
        // Append the row so the filtered arrays contain no empty slots.
        for (var key of Object.keys(dataObject)){
            accumulator[key].push(dataObject[key][index]);
        }
        return accumulator;
    }
    
    if (cat === 'All'){
        data = unfiltered_data.data;
    } else {
        var new_data =  generateNewDataObject(data);
        for (var i = 0; i < unfiltered_data.data['category'].length; i++){
            if (unfiltered_data.data['category'][i] == cat) {
                new_data = addRowToAccumulator(new_data, unfiltered_data.data, i);
            }
        }
        data = new_data;
    }
    
    filtered_data.data = data;
    filtered_data.change.emit();
    """
    )
    
    # Link actions
    dropdown.js_on_event(bev.MenuItemClick, callback)
    
    # Plot (Bokeh 3.x uses width/height instead of plot_width/plot_height)
    p1 = figure(plot_width=800, plot_height=530, title=None)
    
    p1.scatter(x='x', y='y', source=filtered_data)
    
    show(bla.column(dropdown, p1))
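
The row-copying logic can also be tried outside Bokeh. The following is a minimal sketch in plain JavaScript, assuming the data has the same object-of-arrays shape as source.data. The function filterRows and the sample columns object are illustrative names, not Bokeh API. The function appends matching rows with push and leaves its input untouched, which is why one unfiltered source can serve every later selection.

```javascript
// Sample data in the same column-oriented shape as source.data
// (values copied from the example DataFrame; the object itself is hypothetical).
const columns = {
    x: [1, 2, 3, 4, 5],
    y: [1, 2, 3, 4, 5],
    category: ['a', 'a', 'a', 'b', 'b'],
    'other information': [0.5, 0.5, 0.75, 1, 2],
};

// Return a new object of arrays holding only the rows where
// columns[key][i] === wanted. The input object is not modified.
function filterRows(columns, key, wanted) {
    const result = {};
    for (const name of Object.keys(columns)) {
        result[name] = [];
    }
    for (let i = 0; i < columns[key].length; i++) {
        if (columns[key][i] === wanted) {
            for (const name of Object.keys(columns)) {
                result[name].push(columns[name][i]);
            }
        }
    }
    return result;
}

// Both filters start from the same untouched original.
const onlyA = filterRows(columns, 'category', 'a');
const onlyB = filterRows(columns, 'category', 'b');

console.log(onlyA.x);                    // [ 1, 2, 3 ]
console.log(onlyB['other information']); // [ 1, 2 ]
console.log(columns.x.length);           // 5
```

Because every column is filtered with the same row indices, extra columns such as 'other information' stay aligned with x and y and remain usable for hover tooltips or glyph styling.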