Search code examples
Tags: python, python-3.x, drop-down-menu, plotly, plotly-python

‘update_layout’ in plotly not working after a few clicks


I have a dropdown menu where I can choose the x- and y-axis variables for a scatter plot. In addition, a categorical variable can be selected in the menu to determine how the points are colored. This works for a few clicks, but then the hover box shows the literal text ‘%{customdata[0]}’ and the plot is no longer correct. I am using plotly 5.9.0 in JupyterLab 3. To make the coloring variable selectable, I add one set of traces per categorical variable and toggle their visibility. Below is a reproducible example:

# Reproducible example from the question: the dropdowns appear to work for a few
# clicks, then the hover box shows the raw '%{customdata[0]}' template and the
# points are drawn wrongly. The code is kept exactly as posted; the BUG comments
# below mark the lines responsible.
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

X = pd.DataFrame({  'num1': [1,2,3,4],
                    'num2': [40,30,20,10],
                    'num3': [0,1,2,3],
                    'cat1': ['A', 'A', 'A', 'B'],
                    'cat2': ['c', 's', 's', 's'],
                    'cat3': ['a', 'b', 'c', 'd']})

# Numerical columns (here num1, num2, num3) feed the x/y menus; the remaining
# columns (cat1, cat2, cat3) are the candidate color variables. Both are sorted.
numerical_features   = sorted(X.select_dtypes(include=np.number).columns.tolist())
categorical_features = sorted(list(set(X.columns) - set(numerical_features)))

# Initial axes: the first two numerical features.
feature_1 = numerical_features[0]
feature_2 = numerical_features[1]

fig = go.Figure()

# For each categorical feature, px.scatter(color=...) splits the rows into one
# trace per distinct value, and each trace holds only that value's rows (x, y and
# customdata). With this data that is 2 (cat1) + 2 (cat2) + 4 (cat3) = 8 traces,
# not one trace per categorical feature.
for categorical_feature_id in range(len(categorical_features)):

    fig.add_traces(list(px.scatter(X, x=feature_1, y=feature_2, color=categorical_features[categorical_feature_id],
                                         labels={feature_1:feature_1, feature_2:feature_2},
                                         hover_data=['cat3', 'num3']).select_traces()))

fig.update_layout(
        xaxis_title=feature_1,
        yaxis_title=feature_2,
        updatemenus=[
            {
                "buttons": [
                    {
                        "label": f"x - {x}",
                        "method": "update",
                        "args": [
                            # BUG: no trace indices are passed, so this data update
                            # hits all 8 traces and the one-element list is reused for
                            # each of them. Every trace receives the full 4-row column,
                            # while its customdata (and its y, until the y menu does the
                            # same) still holds only that trace's rows. Hovering points
                            # past the end of customdata shows the raw template, and
                            # each trace now draws every row in its own single color.
                            {"x": [X[x]]},
                            {"xaxis": {"title": x}},
                        ],
                    }
                    for x in numerical_features
                ]
            },
            {
                "buttons": [
                    {
                        "label": f"y - {y}",
                        "method": "update",
                        "args": [
                            # BUG: same problem as the x menu -- the full y column is
                            # written into every subset trace.
                            {"y": [X[y]]},
                            {"yaxis": {"title": y}}
                        ],
                    }
                    for y in numerical_features
                ],
                "y": 0.8,
            },
            {
                "buttons": [
                    {
                        "label": f"z - {categorical_features[categorical_feature_id]}",
                        "method": "update",
                        # BUG: the 'visible' mask has len(categorical_features) == 3
                        # entries, but the figure has 8 traces. The list is reused
                        # cyclically, so its True entries do not line up with the
                        # traces that belong to the selected feature.
                        # BUG: the second dict is the *layout* update, and
                        # layout.showlegend is a single boolean, not a per-trace list.
                        "args": [{'visible':    [False if (i<categorical_feature_id) or (i>categorical_feature_id) else True for i in range(len(categorical_features))]},
                                 {'showlegend': [False if (i<categorical_feature_id) or (i>categorical_feature_id) else True for i in range(len(categorical_features))]}]
                    }
                    for categorical_feature_id in range(len(categorical_features))
                ],
                "y": 0.6,
            }])
fig.show()

An example of how the figure looks after a few updates

A similar issue has been discussed for R:

Dropdown menu for changing the color attribute of data in scatter plot (Plotly R)

I would be grateful for any help.


Solution

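  • Hello, I updated your code a little.

    I think reshaping the data is necessary here. px.scatter(color=...) splits the rows into one trace per category. The x/y buttons then write the full column into every one of those traces, so x, y and customdata no longer line up.

    I replaced px.scatter with go.Scatter(), and now the hover box works.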

    I hope this does the trick.

    import pandas as pd
    import numpy as np
    import seaborn as sns
    import plotly.graph_objects as go
    from collections import defaultdict
    
    X = pd.DataFrame({  'num1': [1,2,3,4],
                        'num2': [40,30,20,10],
                        'num3': [0,1,2,3],
                        'cat1': ['A', 'A', 'A', 'B'],
                        'cat2': ['c', 's', 's', 's'],
                        'cat3': ['a', 'b', 'c', 'd']})
    
    numerical_features   = sorted(X.select_dtypes(include=np.number).columns.tolist())
    categorical_features = sorted(set(X.columns) - set(numerical_features))
    
    # Initial axes, as in the question: first numerical feature on x, second on y.
    feature_1 = numerical_features[0]
    feature_2 = numerical_features[1] if len(numerical_features) > 1 else feature_1
    
    # One fixed color per class value, shared across all categorical features.
    # as_hex() yields '#rrggbb' strings. The raw palette holds floats in [0, 1],
    # which an 'rgb(...)' string would read on a 0-255 scale (i.e. almost black).
    unique_classes = list(pd.unique(X[categorical_features].values.ravel()))
    palette        = sns.color_palette(n_colors=len(unique_classes)).as_hex()
    dict_cat_color = dict(zip(unique_classes, palette))
    
    # One frame per categorical feature: all numerical columns plus that feature
    # renamed to 'cat', so every data trace has the same customdata layout.
    features_w_cat = numerical_features + ['cat']
    dfs_list = []
    for categorical_feature in categorical_features:
        frame = X[numerical_features + [categorical_feature]].copy()
        frame.columns  = features_w_cat
        frame["color"] = frame["cat"].map(dict_cat_color)
        dfs_list.append(frame)
    
    # Trace order: [one legend placeholder per class] + [one data trace per feature].
    # Selecting a feature shows the placeholders of the classes it contains plus
    # its own data trace.
    n_placeholders = len(unique_classes)
    n_data         = len(dfs_list)
    feature_values = {feature: set(X[feature]) for feature in categorical_features}
    class_in_feature = defaultdict(list)
    for cls in unique_classes:
        for feature in categorical_features:
            class_in_feature[feature].append(cls in feature_values[feature])
    visibility = {
        feature: class_in_feature[feature] + [j == idx for j in range(n_data)]
        for idx, feature in enumerate(categorical_features)
    }
    # Start in the state the first z button selects, instead of showing every trace.
    initial_visible = visibility[categorical_features[0]] if categorical_features else []
    
    fig = go.Figure()
    
    # Workaround for the legend: empty traces that only carry a class name and color.
    # Clicking a legend entry toggles only this placeholder, not the data points.
    for trace_id, cls in enumerate(unique_classes):
        fig.add_trace(go.Scatter(
            x            = [None],
            y            = [None],
            name         = cls,
            marker_color = dict_cat_color[cls],
            mode         = "markers",
            showlegend   = True,
            visible      = initial_visible[trace_id],
        ))
    
    # The hover box lists every customdata column; <extra></extra> hides the
    # secondary "trace N" label.
    hovertemplate = '<br>'.join(
        f'{name}=%{{customdata[{i}]}}' for i, name in enumerate(features_w_cat)
    ) + '<extra></extra>'
    
    for idx, frame in enumerate(dfs_list):
        fig.add_trace(go.Scatter(
            x             = frame[feature_1],
            y             = frame[feature_2],
            marker_color  = frame["color"],
            customdata    = frame[features_w_cat],
            mode          = "markers",
            hovertemplate = hovertemplate,
            showlegend    = False,
            visible       = initial_visible[n_placeholders + idx],
        ))
    
    fig.update_layout(
        xaxis_title = feature_1,
        yaxis_title = feature_2,
        updatemenus = [
            {
                "active": numerical_features.index(feature_1),
                "buttons": [
                    {
                        "label": f"x - {x}",
                        "method": "update",
                        "args": [
                            # One entry per trace. Placeholders stay empty so they never
                            # draw points. Every data trace gets the full column, because
                            # each frame keeps all rows of X in the same order.
                            {"x": [[None]] * n_placeholders + [X[x]] * n_data},
                            # Dotted path changes only the title, not the rest of the axis.
                            {"xaxis.title.text": x},
                        ],
                    }
                    for x in numerical_features
                ],
            },
            {
                "active": numerical_features.index(feature_2),
                "buttons": [
                    {
                        "label": f"y - {y}",
                        "method": "update",
                        "args": [
                            {"y": [[None]] * n_placeholders + [X[y]] * n_data},
                            {"yaxis.title.text": y},
                        ],
                    }
                    for y in numerical_features
                ],
                "y": 0.8,
            },
            {
                "buttons": [
                    {
                        "label": f"z - {feature}",
                        "method": "update",
                        "args": [{'visible': visibility[feature]}],
                    }
                    for feature in categorical_features
                ],
                "y": 0.6,
            },
        ],
    )
    
    fig.show()