I'm looking for a "clean" way to remove the trendline from the marginal-distribution subplot created using plotly-express. I know it's a bit unclear, so please look at the following example:
Generating some fake data:
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
Creating a scatter plot with both marginal and trendline options:
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
This yields a figure with a trendline in all 3 panels:
I looked into the fig.data struct and found that the trendlines are the last 3 objects in it, and the last 2 are the lines appearing in the top and right panels. Removing those objects from the struct removes the lines from those panels, as seen here:
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
This creates a new issue, because it also removes the trendline from the legend, which is not a behavior I'm happy with. So I first need to update the 3rd-to-last object (the main panel's trendline) to have the showlegend=True attribute:
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
This finally gives me the figure I wanted:
So I do have a solution, but it requires "manhandling" the fig object.
Is there a better, cleaner way of achieving the same final figure?
###############
Full code:
import copy
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px
pio.renderers.default = "browser"
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
fig.show()
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()
You can use the Figure.update_traces() method, which applies the given properties to all traces that match the selector parameter (there is no function to remove traces, but you can hide them using the visible property).
All OLS trendline traces share the same name ("Overall Trendline", which is determined by trendline_scope), and you can use their xaxis (or yaxis) reference to distinguish between them (i.e. "x" refers to the x-axis of the main subplot, while "x2" and "x3" refer respectively to the right and the top axes/subplots).
For example:
import numpy as np
import pandas as pd
import plotly.express as px
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
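# Hide all "Overall Trendline" traces (main, right and top subplots)...
fig.update_traces(visible=False, selector=dict(name='Overall Trendline'))
# ...then show the one on the main subplot again, with a legend entry.
fig.update_traces(visible=True, showlegend=True, selector=dict(name='Overall Trendline', xaxis='x'))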
fig.show()