Search code examples
Tags: python, plotly-python, reactive, py-shiny

How do I filter a table by clicking on a bar chart segment?


I built an R Shiny dashboard with plotly for R, using the Palmer Penguins dataset, so that when I click on a bar chart segment, the click event data is used to filter a dataset.

On Hover:
[Screenshot: the R Shiny dashboard while hovering over a bar segment]

On Click:
[Screenshot: the R Shiny dashboard after clicking a bar segment, with the dataset filtered to that segment]

I wanted to build a similar dashboard using Shiny for Python, but I haven't been able to get the customdata to work the way it did in R, and I cannot figure out how to handle a click event.

I've been able to establish the customdata object to capture the categorical data I want for each bar segment, but I wasn't able to find an event_click() function in plotly for Python similar to the one in plotly for R.

Can someone please tell me how I can get this functionality working? Presently this code runs the shiny for python dashboard, but doesn't respond to any sort of click events anywhere on the bar chart.

I think I probably need help in the @render_widget part of the server function. Here is my Python Source Code:

# Load data and compute static values
from shiny import App, reactive, render, ui
from shinywidgets import output_widget, render_widget, render_plotly
from plotnine import ggplot, aes, geom_bar
import plotly.graph_objects as go
import palmerpenguins
from plotly.callbacks import Points, InputDeviceState
# Module-level plotly callback helper objects; they are not referenced
# anywhere else in this listing.
points = Points()
state = InputDeviceState()


# Full dataset, loaded once when the module is imported.
df_penguins = palmerpenguins.load_penguins()

# Radio-button value (data column name) -> label shown in the UI.
dict_category = dict(species='Species', island='Island', sex='Gender')

def filter_shelf():
    """Build the sidebar card holding the Gender, Species and Island checkbox filters.

    Each checkbox group offers every distinct value of its column in
    ``df_penguins`` and starts with all of them selected.
    """
    # Distinct values per column, computed once and reused for both the
    # offered choices and the initial selection.
    sex_values = list(df_penguins['sex'].unique())
    species_values = list(df_penguins['species'].unique())
    island_values = list(df_penguins['island'].unique())

    return ui.card(
        ui.card_header(
            "Filters",
            align="center",
        ),
        # Gender Filter: string values are shown capitalized; any non-string
        # value (e.g. a missing entry) is passed through unchanged.
        ui.input_checkbox_group(
            'sex_filter',
            label='Gender',
            choices={
                value: (value.capitalize() if isinstance(value, str) else value)
                for value in sex_values
            },
            selected=sex_values,
        ),

        # Species Filter
        ui.input_checkbox_group(
            'species_filter',
            label='Species',
            choices=species_values,
            selected=species_values,
        ),

        # Island Filter
        ui.input_checkbox_group(
            'island_filter',
            label='Island',
            choices=island_values,
            selected=island_values,
        ),
    )

def parameter_shelf():
    """Build the sidebar card with the "View Penguins by" radio buttons.

    The selected value is a key of ``dict_category`` ('species', 'island' or
    'sex') and names the column that groups and colors the chart.

    Returns the card itself; a trailing comma after ``ui.card(...)`` would
    silently turn the return value into a one-element tuple.
    """
    return ui.card(
        ui.card_header(
            'Parameters',
            align='center',
        ),

        # Category Selector
        ui.input_radio_buttons(
            'category',
            label='View Penguins by:',
            choices=dict_category,
            selected='species',
        ),
    )

# Page layout: a sidebar holding the filter and parameter cards, and a main
# panel holding a chart card and a table card. The output IDs used here
# ('chart_title', 'penguin_plot', 'total_rows', 'table_view') are bound to the
# render functions of the same names defined in server().
app_ui = ui.page_fluid(
    ui.h1("Palmer Penguins Analysis"),
    ui.layout_sidebar(
        # Left Sidebar
        ui.sidebar(
            filter_shelf(),
            parameter_shelf(),
            width=250,
        ),
        
        # Main Panel
        ui.card( # Plot
            ui.card_header(ui.output_text('chart_title')),
            output_widget('penguin_plot'),
        ),
        ui.card( # Table
            ui.card_header(ui.output_text('total_rows')),
            # Full-width column with a fixed 300px height that scrolls vertically.
            ui.column(
                12, #width
                ui.output_table('table_view'),
                style="height:300px; overflow-y: scroll"
            )
        )
    ),
)

def server (input, output, session):
    """Server logic for the question's click-to-filter attempt (does not work as written).

    NOTE(review): clicking a bar never filters anything here, because:
      * the click handler is registered only on ``fig.data[0]`` (the first trace);
      * the handler is ``filter_fn``, a ``@reactive.calc`` wrapping a function
        that takes no arguments, while plotly calls click handlers with
        (trace, points, state);
      * nothing reads ``filter_fn``, so Shiny never evaluates it;
      * ``total_rows`` and ``table_view`` read ``df_filtered_stage1()``, so a
        click-based filter added to ``df_filtered_stage2()`` would not reach
        the table.
    """
    
    @reactive.calc
    def category():
        '''This function caches the appropriate Capitalized form of the selected category'''
        return dict_category[input.category()]

    # Dynamic Chart Title
    @render.text
    def chart_title():
        return "Number of Palmer Penguins by Year, colored by "+category()


    @reactive.calc
    def df_filtered_stage1():
        '''This function caches the filtered dataframe based on the sidebar checkbox selections'''
        return df_penguins[
            (df_penguins['species'].isin(input.species_filter())) &
            (df_penguins['island'].isin(input.island_filter())) &
            (df_penguins['sex'].isin(input.sex_filter()))]

    @reactive.calc
    def df_filtered_stage2():
        '''Currently just passes the stage-1 rows through unchanged.'''
        df_filtered_st2 = df_filtered_stage1() 
        # Eventually add additional filters on dataset from segments selected on the visual
        return df_filtered_st2 
    
    @reactive.calc
    def df_summarized():
        '''Per (year, selected category) counts with columns year / <category> / count.'''
        # NOTE(review): DataFrame.count() counts non-null values, so 'count' is the
        # number of rows with a non-null body_mass_g, not strictly the row count.
        return df_filtered_stage2().groupby(['year',input.category()], as_index=False).count().rename({'body_mass_g':"count"},axis=1)[['year',input.category(),'count']]

    # NOTE(review): intended click handler, but as a @reactive.calc it is only
    # evaluated when something reads it, and nothing does.
    @reactive.calc
    def filter_fn():
        print("Clicked!") # This never gets called
        
    @render_widget
    def penguin_plot():
        '''Stacked bar chart of counts per year, one bar trace per category value.'''
        df_plot = df_summarized()
        bar_columns = list(df_plot['year'].unique()) # x axis column labels
        bar_segments = list(df_plot[input.category()].unique()) # bar segment category labels
        # NOTE(review): every trace gets the full year list as x, but y only holds the
        # years present for that segment, so x and y stay aligned only if every
        # segment has a count for every year. customdata=[input.category()] is a
        # single-element list (the column name), not per-bar data, and customdatasrc
        # is a Chart Studio grid reference.
        data = [go.Bar(name=segment, x=bar_columns,y=list(df_plot[df_plot[input.category()]==segment]['count'].values), customdata=[input.category()], customdatasrc='A') for segment in bar_segments]
        fig = go.Figure(data)
        fig.update_layout(barmode="stack")
        fig = go.FigureWidget(fig)

        ##### TRYING TO CAPTURE CLICK EVENT ON A BAR SEGMENT HERE #####
        # NOTE(review): only the first trace gets a handler here.
        fig.data[0].on_click(
            filter_fn
            )
        return fig
    
    @render.text
    def total_rows():
        # NOTE(review): reads stage 1, so a stage-2 (click) filter would not be reflected.
        return "Total Rows: "+str(df_filtered_stage1().shape[0])

    @render.table
    def table_view():
        # NOTE(review): df_this is computed but never used.
        df_this=df_summarized()
        return df_filtered_stage1()

app = App(app_ui, server)

It looks like I can only capture click events per trace. I'm wondering if there is a better way than what I'm doing above, because fig.data has 3 bar traces at runtime when viewing by "Species" (Gentoo, Chinstrap, and Adelie), and it seems that each bar trace has its own on_click() method.


Solution

  • Below is a variant in which I changed several parts; here are the most important ones:

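    • In your example filter_fn never gets invoked. A @reactive.calc is evaluated lazily, only when some output or effect reads it, and nothing reads filter_fn, so Shiny has no need to call it. In addition, the click handler was registered only on fig.data[0] (the first trace). Also, plotly calls click handlers with three arguments (trace, points, state), whereas filter_fn takes none.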

      You can do it like this: define a reactive.value called filter that holds the filter information from the on_click event (filter = reactive.value({})), and the function

      def setClickedFilterValues(trace, points, selector):
        # Only a call carrying point indices comes from the clicked trace;
        # calls for the other traces leave the filter untouched.
        if points.point_inds:
            filter.set({"Year": points.xs, "Category": points.trace_name})
      

      which is registered as the on_click handler of every trace:

      # Register the same click handler on every bar trace of the figure.
      for trace in fig.data:
        trace.on_click(setClickedFilterValues)
      

      plotly invokes the handler for the traces of the figure, and points.point_inds is empty for a trace that was not clicked. The if clause ignores those calls, so filter ends up holding the values of the clicked segment. Note that the function deliberately has no reactive decorator such as @reactive.calc. It is a plain callback that only sets a reactive value, and everything that reads filter reacts to that change.

    • I modified df_filtered_stage2() to take filter into account; it computes the data frame shown in the table below the chart.

      @reactive.calc
      def df_filtered_stage2():
        df_stage1 = df_filtered_stage1()
        selection = filter.get()
        if not selection:
            # No bar segment clicked yet: pass the sidebar-filtered rows through.
            return df_stage1
        in_year = df_stage1['year'].isin(selection['Year'])
        in_category = df_stage1[input.category()].isin([selection['Category']])
        return df_stage1[in_year & in_category]
      
    • Similarly to the above (and to your R app), you can handle the on_hover event; this is included in the code below.

    It looks like this:

    [Screenshot: the resulting Shiny for Python dashboard]

    # Load data and compute static values
    from shiny import App, reactive, render, ui
    from shinywidgets import output_widget, render_widget, render_plotly
    from plotnine import ggplot, aes, geom_bar
    import plotly.graph_objects as go
    import palmerpenguins
    from plotly.callbacks import Points, InputDeviceState
    # Module-level plotly callback helper objects; they are not referenced
    # anywhere else in this listing.
    points = Points()
    state = InputDeviceState()


    # Full dataset, loaded once when the module is imported.
    df_penguins = palmerpenguins.load_penguins()

    # Radio-button value (data column name) -> label shown in the UI.
    dict_category = dict(species='Species', island='Island', sex='Gender')
    
    def filter_shelf():
        """Build the sidebar card holding the Gender, Species and Island checkbox filters.

        Each checkbox group offers every distinct value of its column in
        ``df_penguins`` and starts with all of them selected.
        """
        # Distinct values per column, computed once and reused for both the
        # offered choices and the initial selection.
        sex_values = list(df_penguins['sex'].unique())
        species_values = list(df_penguins['species'].unique())
        island_values = list(df_penguins['island'].unique())

        return ui.card(
            ui.card_header(
                "Filters",
                align="center",
            ),
            # Gender Filter: string values are shown capitalized; any non-string
            # value (e.g. a missing entry) is passed through unchanged.
            ui.input_checkbox_group(
                'sex_filter',
                label='Gender',
                choices={
                    value: (value.capitalize() if isinstance(value, str) else value)
                    for value in sex_values
                },
                selected=sex_values,
            ),

            # Species Filter
            ui.input_checkbox_group(
                'species_filter',
                label='Species',
                choices=species_values,
                selected=species_values,
            ),

            # Island Filter
            ui.input_checkbox_group(
                'island_filter',
                label='Island',
                choices=island_values,
                selected=island_values,
            ),
        )
    
    def parameter_shelf():
        """Build the sidebar card with the "View Penguins by" radio buttons.

        The selected value is a key of ``dict_category`` ('species', 'island' or
        'sex') and names the column that groups and colors the chart.

        Returns the card itself; a trailing comma after ``ui.card(...)`` would
        silently turn the return value into a one-element tuple.
        """
        return ui.card(
            ui.card_header(
                'Parameters',
                align='center',
            ),

            # Category Selector
            ui.input_radio_buttons(
                'category',
                label='View Penguins by:',
                choices=dict_category,
                selected='species',
            ),
        )
    
    # Page layout: a sidebar holding the filter and parameter cards, and a main
    # panel with the chart card, a verbatim text box showing the last hovered
    # point, and the table card. The output IDs used here ('chart_title',
    # 'penguin_plot', 'hoverInfoOutput', 'total_rows', 'table_view') are bound
    # to the render functions of the same names defined in server().
    app_ui = ui.page_fluid(
        ui.h1("Palmer Penguins Analysis"),
        ui.layout_sidebar(
            # Left Sidebar
            ui.sidebar(
                filter_shelf(),
                parameter_shelf(),
                width=250,
            ),
            
            # Main Panel
            ui.card( # Plot
                ui.card_header(ui.output_text('chart_title')),
                output_widget('penguin_plot'),
            ),
            ui.output_text_verbatim('hoverInfoOutput'),
            ui.card( # Table
                ui.card_header(ui.output_text('total_rows')),
                # Full-width column with a fixed 300px height that scrolls vertically.
                ui.column(
                    12, #width
                    ui.output_table('table_view'),
                    style="height:300px; overflow-y: scroll"
                )
            )
        ),
    )
       
    
    def server(input, output, session):
        """Server logic: filter the penguin data, draw a stacked bar chart and
        narrow the table to the bar segment the user clicks.

        Click and hover events arrive through plotly ``FigureWidget`` trace
        callbacks (``on_click`` / ``on_hover``). The callbacks only write into
        reactive values, and everything downstream reacts to those values.
        """

        # Last clicked segment as {"Year": [...], "Category": <trace name>};
        # {} means "no click filter".
        filter = reactive.value({})
        # Last hovered points object from plotly; {} until the first hover.
        hoverInfo = reactive.value({})

        @reactive.calc
        def category():
            '''This function caches the appropriate Capitalized form of the selected category'''
            return dict_category[input.category()]

        # Dynamic Chart Title
        @render.text
        def chart_title():
            return "Number of Palmer Penguins by Year, colored by "+category()

        @reactive.calc
        def df_filtered_stage1():
            '''Rows of df_penguins matching the sidebar checkbox filters.'''
            return df_penguins[
                (df_penguins['species'].isin(input.species_filter())) &
                (df_penguins['island'].isin(input.island_filter())) &
                (df_penguins['sex'].isin(input.sex_filter()))]

        @reactive.effect
        @reactive.event(input.category, ignore_init=True)
        def _reset_click_filter():
            '''Drop the clicked-segment filter whenever the category changes.'''
            # A clicked segment names a value of the previously selected column
            # (e.g. "Adelie" under species); matched against a different column
            # it would select nothing and empty the table.
            filter.set({})

        @reactive.calc
        def df_filtered_stage2():
            '''Stage-1 rows, further narrowed to the clicked bar segment if there is one.'''
            df_stage1 = df_filtered_stage1()
            selection = filter.get()
            if not selection:
                return df_stage1
            return df_stage1[
                (df_stage1['year'].isin(selection['Year'])) &
                (df_stage1[input.category()].isin([selection['Category']]))
            ]

        @reactive.calc
        def df_summarized():
            '''Number of rows per (year, selected category); columns year / <category> / count.'''
            # size() counts rows; count() would count only non-null body_mass_g values.
            return (
                df_filtered_stage1()
                .groupby(['year', input.category()])
                .size()
                .reset_index(name='count')
            )

        def setClickedFilterValues(trace, points, selector):
            '''plotly on_click handler: remember the clicked segment's year(s) and trace name.'''
            # Calls for traces that were not clicked carry no point indices.
            if not points.point_inds:
                return
            filter.set({"Year": points.xs, "Category": points.trace_name})

        def setHoverValues(trace, points, selector):
            '''plotly on_hover handler: keep the hovered points for the text output.'''
            if not points.point_inds:
                return
            hoverInfo.set(points)

        @render_widget
        def penguin_plot():
            '''Stacked bar chart of penguin counts per year, one bar trace per category value.'''
            category_col = input.category()
            df_plot = df_summarized()
            data = []
            for segment in df_plot[category_col].unique():
                # Take x and y from the same rows so that a segment missing some
                # year cannot shift its counts onto the wrong bars.
                rows = df_plot[df_plot[category_col] == segment]
                data.append(go.Bar(
                    name=segment,
                    x=rows['year'].tolist(),
                    y=rows['count'].tolist(),
                ))
            fig = go.Figure(data)
            fig.update_layout(barmode="stack")
            fig = go.FigureWidget(fig)

            # plotly attaches click/hover callbacks per trace, so register the
            # handlers on every bar trace rather than only the first one.
            for trace in fig.data:
                trace.on_click(setClickedFilterValues)
                trace.on_hover(setHoverValues)

            return fig

        @render.text
        def hoverInfoOutput():
            return hoverInfo.get()

        @render.text
        def total_rows():
            return "Total Rows: "+str(df_filtered_stage2().shape[0])

        @render.table
        def table_view():
            return df_filtered_stage2()

    app = App(app_ui, server)