Tags: python, pandas, plotly, interactive

Plotly change color mapping interactively based on data frame values


I have a data frame and want to create a scatter plot from it. The basic code is:

import plotly.express as px
fig = px.scatter(df, x="X", y="Y", color="size")

I want to change the color mapping interactively with a dropdown menu that lists all columns of the data frame (besides X and Y). I probably need to set up something like this correctly:

fig.update_layout(updatemenus=[dict(
    active=0,
    buttons=[
        # my idea would be to use a list comprehension here
        dict(label=col,
             method='restyle',
             args=...)  # <- what goes here?
        for col in df.columns if col not in exclude_columns
    ])])

I found the SO question change-plotly-express-color-variable-with-button, but some things are not clear to me.
It creates two px.scatter figures and then copies their traces into a single graph objects figure (go.Figure):

import plotly.express as px
import plotly.graph_objects as go

df = px.data.iris()  # assumed; the excerpt does not show how df is loaded

fig1 = px.scatter(df, "sepal_length", "sepal_width", color="species")
fig2 = px.scatter(df, "sepal_length", "sepal_width", color="petal_length")

fig = go.Figure()
fig.add_trace(go.Scatter(fig1.data[0], visible=True))
fig.add_trace(go.Scatter(fig1.data[1], visible=True))
fig.add_trace(go.Scatter(fig1.data[2], visible=True))

fig.add_trace(go.Scatter(fig2.data[0], visible=False))

fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=["visible", [True, True, True, False]],
                    label="species",
                    method="restyle"
                ),
                dict(
                    args=["visible", [False, False, False, True]],
                    label="petal length",
                    method="restyle"
                ),
            ]),
        )
    ]
)
fig.show()

Main question:

How can I build this dropdown menu so that it supports an arbitrary number of columns?

Side questions

I might have a misconception here, but I have a moderate number of columns, and creating many px.scatter objects sounds inefficient. Is that really necessary when only the color mapping needs to change?

Likewise, I don't understand the visible parameter I've seen there and in other tutorials. Why does it have length 4? Is there one entry for every species?


Solution

  • I created a working example below to demonstrate how you can change the color of the trace based on a column's values. You could use a similar approach to change the marker size as well.

    I avoid mixing Plotly Express and Graph Objects traces because it can be quite confusing. I prefer the Graph Objects library, as it gives the programmer more control over layout and trace properties.

    Please study the example below.

    import plotly.express as px
    import plotly.graph_objects as go
    import pandas as pd
    from numpy import identity
    
    
    df = pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv")
    
    # filter columns for dropdown
    exclude_cols = ['species']
    labels = [col_name for col_name in df.columns if col_name not in exclude_cols]
    # make one trace per column, each colored by that column; all hidden by default
    data = [go.Scatter(x=df['sepal_length'], y=df['sepal_width'],
                       visible=False, mode='markers',
                       marker_color=df[c], marker_size=30)
            for c in labels]
    # first trace is visible by default
    data[0]['visible'] = True
    # create figure
    fig = go.Figure(
        data=data,
        layout=go.Layout(title="Iris plots"))
    
    # update layout with dropdown menu:
    # row i of the identity matrix makes only trace i visible
    visible = identity(len(labels), dtype=bool)
    buttons= [dict(args=["visible", visible[i, :]], label=l, method="restyle")
                for (i, l) in enumerate(labels)]
    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                showactive=True,
                x=0.05,
                xanchor="left",
                y=1.2,
                yanchor="top"
            ),
        ]
    )
    
    fig.update_layout(
        annotations=[
            dict(
                x=0.01,
                xref="paper",
                y=1.16,
                yref="paper",
                align="left",
                showarrow=False),
        ])
    fig.update_layout(xaxis_title_text='sepal_length',
                      yaxis_title_text='sepal_width',)
    fig.show()
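
On the side question about the length-4 visible list: with method="restyle", a list value for "visible" is applied trace by trace, in the order the traces were added, so it needs one entry per trace. In the linked example, color="species" makes px.scatter split the data into three traces (one per species) and the petal_length figure adds one more, which gives four. The identity matrix above is the special case of exactly one trace per column. A small sketch of a helper (visibility_masks is a hypothetical name, not a Plotly function) that builds the masks for any grouping of traces:

```python
from typing import List


def visibility_masks(trace_counts: List[int]) -> List[List[bool]]:
    """Return one 'visible' list per group of traces.

    trace_counts[g] is how many traces group g added to the figure,
    in the order they were added. Each mask switches on exactly the
    traces belonging to one group and hides all others.
    """
    total = sum(trace_counts)
    masks = []
    start = 0
    for count in trace_counts:
        mask = [False] * total
        for i in range(start, start + count):
            mask[i] = True
        masks.append(mask)
        start += count
    return masks


# The linked example: color="species" added 3 traces, color="petal_length" added 1.
print(visibility_masks([3, 1]))
# [[True, True, True, False], [False, False, False, True]]
```

With Plotly Express figures, the counts could be taken from [len(f.data) for f in figs], and each mask then goes into a button as args=["visible", mask].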
    
    

    As an aside, if you need more interactivity, it's worth learning about Dash. I think callbacks make this kind of thing much easier in Dash than in plain Plotly. Whether you are willing and able to invest the time to learn it is entirely up to you.
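
Before reaching for Dash, here is one way to address the efficiency side question in plain Plotly. You can keep a single trace and swap only its marker.color with restyle, instead of creating one trace per column. This is a sketch under assumptions: color_buttons and main are hypothetical names, the x/y columns come from the iris example above, and non-numeric columns are dropped with select_dtypes because strings such as "setosa" would be read as color names. Each column's values still travel inside its button's args, so the HTML is not necessarily smaller, but only one trace is drawn.

```python
def color_buttons(columns, exclude=("X", "Y")):
    """Return one restyle button per column that recolors trace 0.

    `columns` may be a pandas DataFrame or a plain dict of
    column name -> values; both provide .items().
    """
    buttons = []
    for name, values in columns.items():
        if name in exclude:
            continue
        buttons.append(dict(
            label=name,
            method="restyle",
            # restyle spreads a top-level list across traces, so the
            # per-point array for one trace is wrapped in an outer list;
            # the trailing [0] limits the update to trace 0.
            args=[{"marker.color": [list(values)]}, [0]],
        ))
    return buttons


def main():
    # Imported here so color_buttons() works without pandas/plotly installed.
    import pandas as pd
    import plotly.graph_objects as go

    df = pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv")
    # keep numeric columns only: strings such as "setosa" are not valid colors
    numeric = df.select_dtypes(include="number")
    buttons = color_buttons(numeric, exclude=("sepal_length", "sepal_width"))

    fig = go.Figure(go.Scatter(
        x=df["sepal_length"], y=df["sepal_width"], mode="markers",
        # start with the column of the first dropdown entry
        marker=dict(color=numeric[buttons[0]["label"]], showscale=True),
    ))
    fig.update_layout(
        updatemenus=[dict(buttons=buttons, active=0, showactive=True,
                          x=0.05, xanchor="left", y=1.2, yanchor="top")],
        xaxis_title_text="sepal_length",
        yaxis_title_text="sepal_width",
    )
    fig.show()
```

Calling main() reads the CSV and opens the figure. Restyling "marker.size" in the same way (with suitably scaled values) would switch the marker size instead of the color.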