
Adding labels in the PairGrid diagonal


In the attached plot, I want to display labels on the diagonal (in red). I tried methods like map_diag from PairGrid. Is there a way to display the labels?

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def hide_current_axis(*args, **kwds):
    # map_upper makes each upper-triangle axes current; hide it
    plt.gca().set_visible(False)

# df is the asker's DataFrame (not shown in the question)
g = sns.PairGrid(df, diag_sharey=False)
g.map_lower(sns.regplot, lowess=True, scatter_kws={"color": "black"}, line_kws={"color": "red"})
g.map_upper(hide_current_axis)
for ax in g.axes.flatten():
    # enlarge the x axis label
    ax.set_xlabel(ax.get_xlabel(), fontsize=30)
    # enlarge the y axis label
    ax.set_ylabel(ax.get_ylabel(), fontsize=30)
    # right-align the y axis label
    ax.yaxis.get_label().set_horizontalalignment('right')


map_diag documentation: https://seaborn.pydata.org/generated/seaborn.PairGrid.map_diag.html


Solution

  • Instead of creating a PairGrid and mapping sns.regplot onto it, you could call sns.pairplot(kind='reg', ...) directly. The arguments for regplot go into the plot_kws= parameter.

    Instead of hiding the upper half of the grid, you can pass corner=True. The font size of the axis labels can be set by creating the plot inside a with sns.plotting_context(rc={"axes.labelsize": 20}): block. The texts on the diagonal can be set via a custom function passed to g.map_diag.
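The custom function given to g.map_diag has to accept the arguments map_diag passes to it. A minimal sketch, assuming seaborn 0.11 or newer and no hue: for a function defined outside seaborn, map_diag makes each diagonal axes current (so plt.gca() returns it) and calls the function with that column as a pandas Series plus label= and color= keywords. The probe record_diag_call and the toy DataFrame are hypothetical, only there to show what arrives.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

calls = []  # one entry per diagonal cell: (column name, current axes)

def record_diag_call(data, label, color):
    # Hypothetical probe: 'data' is the column as a pandas Series,
    # so data.name is the variable shown in this diagonal cell.
    calls.append((data.name, plt.gca()))

# Toy data with three numeric columns (made up for illustration).
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(30, 3)),
                  columns=["sepal_length", "sepal_width", "petal_length"])

g = sns.PairGrid(df)
g.map_diag(record_diag_call)

print([name for name, _ in calls])  # ['sepal_length', 'sepal_width', 'petal_length']
plt.close("all")  # this probe only inspects the calls, so discard the figure
```

Because the Series keeps its column name, data.name is enough to tell which variable a diagonal cell belongs to.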

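    import matplotlib.pyplot as plt
    import seaborn as sns
    
    def set_diag_name(data, label, color):
        # map_diag passes each column as a pandas Series (label and color are unused)
        ax = plt.gca()
        ax.cla()        # remove the histogram pairplot drew on the diagonal
        ax.axis('off')  # hide ticks and frame of the diagonal axes
        ax.text(x=0.5, y=0.5, s=data.name.replace('_', '\n'), fontsize=30, color='red',
                ha='center', va='center', transform=ax.transAxes)
    
    iris = sns.load_dataset("iris")
    # axis labels created inside this block use the larger font size
    with sns.plotting_context(rc={"axes.labelsize": 20}):
        # kind='reg' draws sns.regplot in each panel; plot_kws is forwarded to it
        g = sns.pairplot(iris, corner=True, kind='reg',
                         plot_kws={'lowess': True,
                                   'scatter_kws': {"color": "black"},
                                   'line_kws': {"color": "red"}})
    # replace each diagonal plot with a red text label
    g.map_diag(set_diag_name)
    plt.show()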
    

[Figure: sns.pairplot with text labels on the diagonal]
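If you would rather keep the explicit PairGrid from the question, a similar diagonal function can be mapped there directly. A minimal sketch, assuming seaborn 0.11 or newer (for corner=True and the Series passed by map_diag); the toy DataFrame and the name label_diagonal are hypothetical, and lowess=True is left out only because regplot needs statsmodels for it.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

def label_diagonal(data, label, color):
    # Hypothetical helper: write the column name in the middle of the
    # current diagonal axes instead of drawing a histogram there.
    ax = plt.gca()
    ax.axis("off")
    ax.text(0.5, 0.5, data.name.replace("_", "\n"), fontsize=30, color="red",
            ha="center", va="center", transform=ax.transAxes)

# Toy stand-in for the question's df (made up for illustration).
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(50, 3)),
                  columns=["sepal_length", "sepal_width", "petal_length"])

# corner=True leaves out the upper triangle, so hide_current_axis is not needed.
g = sns.PairGrid(df, diag_sharey=False, corner=True)
g.map_lower(sns.regplot, scatter_kws={"color": "black"}, line_kws={"color": "red"})
g.map_diag(label_diagonal)
```

In an interactive session, call plt.show() afterwards to display the figure. Unlike the pairplot version, nothing is drawn on the diagonal before map_diag runs, so the helper can skip ax.cla().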