Search code examples
pythonpandasdataframegraphplotly

How can I optimize a plotly graph with updatemenues?


So, I have been using plotly a lot and recently came to use the updatemenus method for adding buttons. I've created several graphs with it, but I find it difficult to find an efficient method to update the args section in updatemenus sections. I have a data frame that is bigger than the example but it’s the same idea, so I have df:

name    unaregate   value   age
input1  in11           2    0
input1  in11           0    1
input1  in11           2    2
input1  in11           3    3
input1  in11           1    4
input1  in12           1    0
input1  in12           3    1
input1  in12           4    2
input1  in12           2    3
input1  in12           3    4
input1  in13           0    0
input1  in13           2    1
input1  in13           4    2
input1  in13           2    3
input1  in13           3    4
input2  in21           3    0
input2  in21           4    1
input2  in21           2    2
input2  in21           1    3
input2  in21           3    4
input2  in22           4    0
input2  in22           0    1
input2  in22           2    2
input2  in22           4    3
input2  in22           0    4
input2  in23           3    0
input2  in23           4    1
input2  in23           0    2
input2  in23           4    3
input2  in23           2    4
input3  in31           3    0
input3  in31           4    1
input3  in31           2    2
input3  in31           4    3
input3  in31           1    4
input3  in32           4    0
input3  in32           0    1
input3  in32           0    2
input3  in32           2    3
input3  in32           1    4
input3  in33           2    0
input3  in33           3    1
input3  in33           0    2
input3  in33           3    3
input3  in33           4    4
input3  in34           2    0
input3  in34           2    1
input3  in34           3    2
input3  in34           4    3
input3  in34           3    4

Here is a super inefficient way to create a data frame similar to this:

df = pd.DataFrame(index=range(5),columns=range(1))
df12 = pd.DataFrame(index=range(5),columns=range(1))
df13 = pd.DataFrame(index=range(5),columns=range(1))
df21 = pd.DataFrame(index=range(5),columns=range(1))
df22 = pd.DataFrame(index=range(5),columns=range(1))
df23 = pd.DataFrame(index=range(5),columns=range(1))
df31 = pd.DataFrame(index=range(5),columns=range(1))
df32 = pd.DataFrame(index=range(5),columns=range(1))
df33 = pd.DataFrame(index=range(5),columns=range(1))
df34 = pd.DataFrame(index=range(5),columns=range(1))
df["name"] = "input1"
df["unaregate"] = "in11"
df["value"] = np.random.randint(0,5, size=len(df))
df["age"] = range(0,len(df))
​
df12["name"] = "input1"
df12["unaregate"] = "in12"
df12["value"] = np.random.randint(0,5, size=len(df12))
df12["age"] = range(0,len(df12))
​
df13["name"] = "input1"
df13["unaregate"] = "in13"
df13["value"] = np.random.randint(0,5, size=len(df13))
df13["age"] = range(0,len(df13))
​
df21["name"] = "input2"
df21["unaregate"] = "in21"
df21["value"] = np.random.randint(0,5, size=len(df21))
df21["age"] = range(0,len(df21))
​
df22["name"] = "input2"
df22["unaregate"] = "in22"
df22["value"] = np.random.randint(0,5, size=len(df22))
df22["age"] = range(0,len(df22))
​
df23["name"] = "input2"
df23["unaregate"] = "in23"
df23["value"] = np.random.randint(0,5, size=len(df23))
df23["age"] = range(0,len(df23))
​
df31["name"] = "input3"
df31["unaregate"] = "in31"
df31["value"] = np.random.randint(0,5, size=len(df31))
df31["age"] = range(0,len(df31))
​
df32["name"] = "input3"
df32["unaregate"] = "in32"
df32["value"] = np.random.randint(0,5, size=len(df32))
df32["age"] = range(0,len(df32))
​
df33["name"] = "input3"
df33["unaregate"] = "in33"
df33["value"] = np.random.randint(0,5, size=len(df33))
df33["age"] = range(0,len(df33))
​
df34["name"] = "input3"
df34["unaregate"] = "in34"
df34["value"] = np.random.randint(0,5, size=len(df34))
df34["age"] = range(0,len(df34))
frames = [df,df12,df13,df21,df22,df23,df31,df32,df33,df34]
df = pd.concat(frames)
df = df.drop([0],axis=1)

This is the method I am employing for the plot:

fig = go.Figure()
names = df.name.unique()
for i in names:
    db = df[df["name"]==i]
    uni = db.unaregate.unique()
    for f in uni:
        fig.add_trace(go.Scatter(
            x=db[db.unaregate==f].age,
            y=db[db.unaregate==f].value,
        connectgaps=False ,visible=False,
        mode='lines', legendgroup=f,name=f))
fig.update_layout(
    template="simple_white",
    xaxis=dict(title_text="age"),
    yaxis=dict(title_text="Value"),
    width=1000, height = 600
)
fig.update_layout(
    updatemenus=[
        dict(
#             type="buttons",
#             direction="down",
            active=0,
#             x=0.7,
#             y=1.2,
#             showactive=True,
            buttons=list(
                [ dict(
                        label="Select name",
                        method="update",
                        args=[
                            {"visible": [False,False,False,
                                         False,False,False,
                                         False,False,False,False
                                         ]},
                        ],
                    ),
                 dict(
                        label="input 1",
                        method="update",
                        args=[
                            {"visible": [True,True,True,
                                         False,False,False,
                                         False,False,False,False
                                         ]},
                        ],
                    ),
                 dict(
                        label="input 2",
                        method="update",
                        args=[
                            {"visible": [False,False,False,
                                         True,True,True,
                                         False,False,False,False
                                         ]},
                        ],
                    ),
                 dict(
                        label="input 3",
                        method="update",
                        args=[
                            {"visible": [False,False,False,
                                         False,False,False,
                                         True,True,True,True
                                         ]},
                        ],
                    ),
]
            ),
#             showactive=True,
        )
    ]
)
fig

In the part were the True’s and False are, is there a way to add those in a loop so when I have more the fifty lines, I do not have to add more than 50 Trues and Fales’s? Any help is Welcomed I just want to be able to run this script for any type of similar data and that the lengths of data do not matter.


Solution

    • data frame creation can be simplified. Using pandas constructor capability with list comprehensions
    • figure / traces creation is far simpler with plotly express
    • core question - dynamically create visible lists
      • the trace is visible if it's in same name group. This where button name corresponds with name level of trace
    import pandas as pd
    import numpy as np
    import plotly.express as px
    
    df = (
        pd.DataFrame(
            [
                {
                    "name": f"input{a}",
                    "unaregate": f"in{a}{b}",
                    "value": np.random.randint(0, 5, 5),
                }
                for a in range(1, 4)
                for b in range(1, 4)
            ]
        )
        .explode("value")
        .pipe(lambda d: d.assign(age=np.random.randint(0, 5, len(d))))
    )
    
    # get valid combinations that will create traces
    combis = df.groupby(["name","unaregate"]).size().index
    
    # for this example - it's far simpler to use plotly express to create traces
    fig = px.line(df, x="age", y="value", color="unaregate").update_traces(visible=False)
    
    # use list comprehesions to populate visible lists
    fig.update_layout(
        updatemenus=[
            {
                "active": 0,
                "buttons": [
                    {
                        "label": "Select name",
                        "method": "update",
                        "args": [{"visible": [False for t in fig.data]}],
                    }
                ]
                + [
                    {
                        "label": n,
                        "method": "update",
                        "args": [{"visible": [n == t for t in combis.get_level_values(0)]}],
                    }
                    for n in combis.get_level_values(0).unique()
                ],
            }
        ],
        template="simple_white"
    )