Tags: python, plotly, plotly-python

Plotly: Scatter plot with dropdown menu and color by group


I'm trying to make a scatter plot with two dropdown menus that select which data column (from a pandas DataFrame) is plotted on the x- and y-axis. I also want the points to be colored by a third, categorical variable that stays fixed (no dropdown needed for this one).

So far, thanks to another post, I've been able to create the scatter plot with working dropdown menus, but I don't know how to color the points by a third variable. Here's the code so far:

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

df = px.data.tips().select_dtypes(['number'])  # drop non-numeric columns
cols = df.columns.tolist()  # columns offered in both dropdowns

fig = go.Figure(
    go.Scatter(
        x=df['total_bill'],
        y=df['tip'],
        hovertemplate='x: %{x} <br>y: %{y}',
        mode="markers"))
fig.update_layout(
    updatemenus=[
        {
            "buttons": [
                {
                    "label": f"x - {x}",
                    "method": "update",
                    "args": [
                        {"x": [df[x]]},
                        {"xaxis": {"title": x}},
                    ],
                }
                for x in cols
            ]
        },
        {
            "buttons": [
                {
                    "label": f"y - {x}",
                    "method": "update",
                    "args": [
                        {"y": [df[x]]},
                        {"yaxis": {"title": x}}
                    ],
                }
                for x in cols
            ],
            "y": 0.9,
        },
    ],
    margin={"l": 0, "r": 0, "t": 25, "b": 0},
    height=700)
fig.show()

Ultimately, I want the same scatter plot, but with the points colored by a categorical column (e.g. "smoker"). Without the dropdown menus, this is easy with Plotly Express:

df = px.data.tips()
fig = px.scatter(df, x="total_bill", y="tip", color="smoker", trendline="ols")  # trendline="ols" requires statsmodels
fig.show()

Does anyone have an idea how I could achieve this?


Solution

  • I was able to find a solution by creating a new column of hexadecimal color codes, one per level of another categorical column, and passing that column as the marker color:

    import plotly.express as px
    import plotly.graph_objects as go

    dfx = px.data.tips()
    # create list of columns to iterate over for buttons
    cols = dfx.columns.values.tolist()
    # hex codes of the plotly.js default colorway (the D3 "category10" palette)
    plotly_colors=[
                    '#1f77b4',  # muted blue
                    '#ff7f0e',  # safety orange
                    '#2ca02c',  # cooked asparagus green
                    '#d62728',  # brick red
                    '#9467bd',  # muted purple
                    '#8c564b',  # chestnut brown
                    '#e377c2',  # raspberry yogurt pink
                    '#7f7f7f',  # middle gray
                    '#bcbd22',  # curry yellow-green
                    '#17becf'   # blue-teal
                  ]
    # create dictionary to associate colors with unique categories
    # (zip() stops at the shorter input, so only the first 10 categories get a color)
    color_dict = dict(zip(dfx['smoker'].unique(), plotly_colors))
    # map new column with hex colors to pass to go.Scatter()
    dfx['hex']= dfx['smoker'].map(color_dict)
    #initialize scatter plot
    fig = go.Figure(
        go.Scatter(
            x=dfx['total_bill'],
            y=dfx['tip'],
            text=dfx['smoker'],
            marker=dict(color=dfx['hex']),
            mode="markers"
        )
    ) 
    # initialize dropdown menus
    fig.update_layout(
        updatemenus=[
            {
                "buttons": [
                    {
                        "label": f"x - {x}",
                        "method": "update",
                        "args": [
                            {"x": [dfx[x]]},
                            {"xaxis": {"title": x}},
                        ],
                    }
                    for x in cols
                ]
            },
            {
                "buttons": [
                    {
                        "label": f"y - {x}",
                        "method": "update",
                        "args": [
                            {"y": [dfx[x]]},
                            {"yaxis": {"title": x}}
                        ],
                    }
                    for x in cols
                ],
                "y": 0.9,
            },
        ],
        margin={"l": 0, "r": 0, "t": 25, "b": 0},
        height=700
    )
    fig.show()
    

    [Figure: final scatter plot with points colored by the categorical column "smoker"]