Tags: python, seaborn, hue, pairplot

Seaborn Pairplot with and without hue on top of each other


I am plotting a dataframe as a pairplot and fitting lowess lines using regplot in seaborn. I have used hue="variable", but would also like to see a lowess line fitted on the whole data (ignoring hue) in the same pairplot. I would appreciate any help or tips on this question.

pg = sns.pairplot(df, hue="variable", plot_kws={'alpha':0.015})

a = pg.map(sns.regplot, lowess=True, scatter_kws={'alpha':0.03})


Solution

  • Let's replace sns.regplot with a custom function, e.g. plot_extra.

    pg.map will call plot_extra once per hue value in each subplot. It will pass the following parameters:

    • label: the name of the hue value of this call
    • x and y: the columns for this subplot, restricted to the current hue value
    • extra parameters as given in the call to pg.map

    To draw the regplot for the complete dataframe, the dataframe can be passed as an extra parameter. To avoid running the same full-data fit once per hue value, the function can check the label and proceed only for one specific label. The .name of x and y identifies which columns to plot. .map() calls the function for every subplot, while .map_offdiag() is restricted to the subplots that aren't on the diagonal.

    The legend can be updated: extract the information from the existing pairplot legend and add a reference to the regplot line.

    Here is an example using the standard iris data set.

    import matplotlib.pyplot as plt
    import seaborn as sns
    
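    def plot_extra(x, y, **kwargs):
        # proceed for one hue value only, so the full-data fit is drawn once per subplot
        if kwargs['label'] == first_label:
            # lowess=True requires statsmodels to be installed
            sns.regplot(data=kwargs['data'], x=x.name, y=y.name, lowess=True, scatter=False, color=kwargs['color'])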
    
    df = sns.load_dataset('iris')
    
    first_label = df['species'][0]
    pg = sns.pairplot(df, hue='species', plot_kws={'alpha': 0.5}, palette='turbo')
    
    pg.map_offdiag(plot_extra, color='crimson', data=df)
    
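    # the existing legend items; matplotlib 3.7 renamed legendHandles to legend_handles
    handles = getattr(pg.legend, 'legend_handles', None)
    if handles is None:
        handles = pg.legend.legendHandles  # older matplotlib versions
    legend_dict = {h.get_label(): h for h in handles}
    # add the first line object of a non-diagonal ax to the legend
    legend_dict['lowess regression'] = pg.axes[0, 1].lines[0]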
    
    pg.legend.remove()  # remove existing legend
    pg.add_legend(legend_dict, label_order=legend_dict.keys(), title='')  # create the new legend
    plt.show()
    

    [Figure: sns.pairplot with extra regplot independent of hue]