I am plotting a dataframe as a pairplot and fitting lowess lines using regplot in seaborn. I have used hue = "variable", but would like to see the lowess line fitted on the whole data (non-hue) in the same pairplot as well. I would appreciate any help or tips on this question.
pg = sns.pairplot(df, hue="variable", plot_kws={'alpha':0.015})
a = pg.map(sns.regplot, lowess=True, scatter_kws={'alpha':0.03})
Let's replace sns.regplot with a custom function, e.g. plot_extra. pg.map will call plot_extra once for each hue value in each subplot. It will pass the following parameters:
- label: the name of the hue value for this call
- x and y: the columns for this subplot, restricted to the current hue value
- any extra keyword arguments given to pg.map
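To see these calls in practice, here is a minimal sketch that records what the grid passes to a custom function instead of drawing anything. The toy dataframe, the column names a, b and group, and the probe function are made up for illustration; sns.PairGrid is used directly because it is the object that sns.pairplot returns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: two numeric columns and a two-level hue column.
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "a": rng.normal(size=40),
    "b": rng.normal(size=40),
    "group": ["u"] * 25 + ["v"] * 15,
})

calls = []

def probe(x, y, **kwargs):
    # Record the hue label, the column names and the number of rows received.
    calls.append((kwargs["label"], x.name, y.name, len(x)))

pg = sns.PairGrid(toy, hue="group")
pg.map(probe)  # all subplots, diagonal included

for call in calls:
    print(call)
```

Each subplot shows up once per hue level, and the length of x matches the size of that hue group rather than the size of the whole dataframe.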
To draw the regplot for the complete dataframe, the dataframe can be passed as an extra parameter. To prevent the same function from being executed again for each of the hue values, the function can test the label and proceed only for one specific label. The .name of x and y can be used to indicate which columns to plot. .map() will call the function for each of the subplots, while .map_offdiag() will be restricted to the subplots that aren't on the diagonal.
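The guard can be checked on its own with a small sketch. The toy data and the name fit_once are hypothetical, and a plain linear fit (ci=None) stands in for the lowess fit only to keep the dependencies small; passing lowess=True instead should work the same way when statsmodels is installed.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: three numeric columns and a three-level hue column.
rng = np.random.default_rng(1)
toy = pd.DataFrame(rng.normal(size=(60, 3)), columns=["a", "b", "c"])
toy["group"] = np.repeat(["u", "v", "w"], 20)

first_label = toy["group"].iloc[0]

def fit_once(x, y, **kwargs):
    # Called once per hue level; only the first level triggers drawing.
    if kwargs["label"] != first_label:
        return
    # x.name and y.name select the columns of the full dataframe.
    sns.regplot(data=kwargs["data"], x=x.name, y=y.name,
                scatter=False, ci=None, color=kwargs["color"])

pg = sns.PairGrid(toy, hue="group")
pg.map_offdiag(fit_once, data=toy, color="crimson")

# Count the line objects per subplot: one off the diagonal, none on it.
n = len(pg.x_vars)
lines_per_ax = [[len(pg.axes[i, j].lines) for j in range(n)] for i in range(n)]
print(lines_per_ax)
```

Without the label test, every off-diagonal subplot would receive one identical line per hue level, all drawn on top of each other.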
The legend can be updated: extract the information from the existing pairplot legend and add a reference to the regplot line.
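As a rough sketch of that legend update on toy data: the entries of the existing legend are read back, one extra entry is added for the fitted line, and the legend is rebuilt. The data, the fit_once helper and the entry name "whole-data fit" are made up for illustration. The handle lookup assumes that Matplotlib 3.7 renamed legendHandles to legend_handles, so both names are tried, and the labels are read from the legend texts.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: two numeric columns and a two-level hue column.
rng = np.random.default_rng(2)
toy = pd.DataFrame(rng.normal(size=(40, 2)), columns=["a", "b"])
toy["group"] = np.repeat(["u", "v"], 20)
first_label = toy["group"].iloc[0]

def fit_once(x, y, **kwargs):
    # Draw the whole-data fit only during the call for the first hue level.
    if kwargs["label"] == first_label:
        sns.regplot(data=kwargs["data"], x=x.name, y=y.name,
                    scatter=False, ci=None, color=kwargs["color"])

pg = sns.PairGrid(toy, hue="group")
pg.map_offdiag(plt.scatter, alpha=0.5)
pg.add_legend()
pg.map_offdiag(fit_once, data=toy, color="crimson")

old = pg.legend
# Accept either attribute name for the legend handles (assumption: renamed in 3.7).
handles = getattr(old, "legend_handles", None)
if handles is None:
    handles = old.legendHandles
labels = [t.get_text() for t in old.get_texts()]

legend_dict = dict(zip(labels, handles))
legend_dict["whole-data fit"] = pg.axes[0, 1].lines[0]  # first line of an off-diagonal ax

old.remove()  # drop the old legend before creating the new one
pg.add_legend(legend_dict, label_order=list(legend_dict), title="")
new_labels = [t.get_text() for t in pg.legend.get_texts()]
print(new_labels)
```

Reading the labels from the legend texts avoids depending on whether the legend handles carry the hue labels themselves.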
Here is an example using the standard iris data set.
import matplotlib.pyplot as plt
import seaborn as sns
def plot_extra(x, y, **kwargs):
    # called once per hue value for each subplot; draw only for the first hue value
    if kwargs['label'] == first_label:
        sns.regplot(data=kwargs['data'], x=x.name, y=y.name, lowess=True, scatter=False, color=kwargs['color'])
df = sns.load_dataset('iris')
first_label = df['species'][0]
pg = sns.pairplot(df, hue='species', plot_kws={'alpha': 0.5}, palette='turbo')
pg.map_offdiag(plot_extra, color='crimson', data=df)
# the existing legend items (legend_handles was named legendHandles before Matplotlib 3.7)
legend_dict = {h.get_label(): h for h in pg.legend.legend_handles}
legend_dict['lowess regression'] = pg.axes[0, 1].lines[0] # add the first line object of a non-diagonal ax to the legend
pg.legend.remove() # remove existing legend
pg.add_legend(legend_dict, label_order=legend_dict.keys(), title='') # create the new legend
plt.show()