Tags: python, streamlit, python-polars

Unable to change default selection in streamlit multiselect()


I'm new to Streamlit apps and can't select any option other than the default when using st.multiselect(). I also tried st.selectbox but ran into a problem with that as well.

Demo web app link for issue: https://party-crime-record.streamlit.app/

Web page code on GitHub: Code link

The multiselect code chunk starts at line 46.

A snapshot of the multiselect code chunk is below:

with st.sidebar:
    State_List = df.lazy().select(pl.col('State')).unique().collect().to_series().to_list()

    # State_Selected = st.selectbox(label="Select State",
    #                               options = State_List)

    State_Selected = st.multiselect(label="Select State",
                                    options=State_List,
                                    default=["Uttar Pradesh"],  # Delhi West Bengal
                                    # default=State_List[-1],
                                    max_selections=1)

Solution

  • Streamlit reruns the whole script every time something happens on the page, so think in terms of functions rather than objects, and decorate those functions with @st.cache_data or @st.cache_resource. Return values that are serializable (like lists) can use @st.cache_data; things that aren't, or that should be shared as a single object across reruns (like the LazyFrame here), use @st.cache_resource.

    You also need to think about how the script runs during transitional states. For instance, what does the rest of the script do while the State selector is blank? Streamlit seems to be trying to be smart and autofill your State, so you need to check explicitly whether a State is selected and return nothing if it isn't. This matters all the more because your State selector is a multiselect with max_selections=1, so to change the State a user first has to clear it, which leaves it blank.

    As an aside, if you load a file with scan_parquet you don't need to call .lazy(), because scan_parquet already returns a LazyFrame.

    The code below seems to work without any calls to to_pandas() except the one needed for px:

    import streamlit as st
    import polars as pl
    import pandas as pd
    import numpy as np
    import plotly.express as px
    
    
    # from: https://youtu.be/lWxN-n6L7Zc
    # StreamlitAPIException: set_page_config() can only be called once per app, and must be called as the first Streamlit command in your script.
    
    
    st.set_page_config(page_title="Election Party Candidate Analysis",
                        layout='wide',
                        initial_sidebar_state="expanded")
    
    
    ############################## GET DATA ##############################
    # @st.cache
    @st.cache_resource
    def get_data():
        df = pl.scan_parquet('https://github.com/johnsnow09/Party_Criminal_Records/raw/main/Elections_Data_Compiled.parquet')
        return df
    
    
    ############################## DATA DONE ##############################


    @st.cache_data
    def get_state_list():
    
        return get_data().select('State').unique().collect().to_series().to_list()
    
    @st.cache_data
    def get_year_list(state):
        years=(
            get_data()
                .filter(pl.col('State').is_in(state))
                .select(pl.col('Year'))
                .unique().collect().to_series()
                .sort()
                .to_list()
        )
        if len(years)>0:
            return years
        else:
            # Placeholder fallback so the year widget always has an option;
            # you might want something else here
            return [2022]


    ############################## CREATING HEADER ##############################
    
    header_left,header_mid,header_right = st.columns([1,6,1],gap = "large")
    
    with header_mid:
        # https://docs.streamlit.io/library/get-started/create-an-app
        st.title("Party Criminal Records Analysis")
    
    ############################## HEADER DONE ##############################


    ############################## FIRST FILTER STATE ##############################
    
    with st.sidebar:
        # State_Selected = st.selectbox(label="Select State",
        #                               options = State_List)
    
        State_Selected = st.multiselect(label="Select State",
                                        options = get_state_list(),
                                        default = "Delhi", # Delhi West Bengal  
                                        #default = get_state_list()[-1],
                                        max_selections=1
                                        )
        
    ############################## FIRST FILTER STATE DONE ##############################


    ############################## SECOND FILTER YEAR ##############################
    
    
        Year_Selected = st.multiselect(label="Select Election Year",
                                        options=get_year_list(State_Selected),
                                        default=get_year_list(State_Selected)[-1])
        
    ############################## SECOND FILTER YEAR DONE ##############################


    ############################## FILTERED DATA ##############################

    ############################## FILTERED DATA DONE ##############################


    ############################## FIRST PLOT ##############################
    @st.cache_resource
    def make_chart(State_Selected, Year_Selected):
        df_selected = (get_data()
                .filter(pl.col('State').is_in(State_Selected) &
                        pl.col('Year').is_in(Year_Selected))
                .collect()
                )
        if df_selected.height==0:
            return None
    
        years_label = ', '.join(map(str, Year_Selected))
        fig_party_crime_sum = px.bar(df_selected.group_by('Party')  # called groupby before Polars 0.19
                                    .agg(pl.col('Criminal_Case').sum())
                                    .sort(by='Criminal_Case', descending=True)
                                    .head(18).to_pandas(),
                                    orientation='h',
                                    x='Criminal_Case', y='Party', color="Party",
                                    labels={
                                            "Criminal_Case": "Total Criminal Cases",
                                            "Party": "Election Parties"
                                        },
                                    title=f'<b>{State_Selected[0]} - Top 18 Election Parties with Total Criminal Records in {years_label} Elections</b>')
        # fig_party_crime_sum.update_yaxes(autorange="reversed")
        fig_party_crime_sum.update_layout(title_font_size=26, height = 600, 
                                            showlegend=False
                                            )
        fig_party_crime_sum.add_annotation(
                                            showarrow=False,
                                            text='Data Source: https://myneta.info/',
                                            xanchor='right',
                                            x=2,
                                            xshift=675,
                                            yanchor='bottom',
                                            y=0.01
                                            # font=dict(
                                            #     family="Courier New, monospace",
                                            #     size=22,
                                            #     color="#0000FF")
                                        )
        return fig_party_crime_sum
    
    # Build the chart once and skip it when the selection matched no rows
    fig = make_chart(State_Selected, Year_Selected)
    if fig is not None:
        st.plotly_chart(fig, use_container_width=True)
    
    ############################## FIRST PLOT DONE ##############################