python, seaborn, pairplot, pairgrid

How to customize axes in Seaborn PairGrid


I am trying to customize a Seaborn PairGrid with the following:

  • use log scale axes
  • control axes limits (I want ylims = xlims for all subplots)
  • color/line weight control for major/minor gridlines

I think it can be done by getting the axes handles, but I am not sure how to do that. This answer is good for JointPlots, but what is the equivalent of ax = g.ax_joint for PairGrids?

I would also like to add a 1:1 identity line, if possible without having to define a function as the answer here does.

import pandas as pd
import numpy as np
import seaborn as sns

np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)
                  })

g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)



Solution

  • (Update: the code now uses axline to draw a diagonal line that reaches the borders of each subplot, as suggested in the comments. axline was added in matplotlib 3.3.0. Note that on a log-log axis, for accuracy reasons, axline still needs one point close to the minimum of the data and another close to the maximum. Those two points are also added to the data limits, so they influence the axis limits.)
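A minimal sketch of that axline behaviour on a single log-log Axes, outside of seaborn. The values 0.02 and 0.98 are illustrative stand-ins for the data minimum and maximum, and the Agg backend is an assumption so the snippet runs without a display. Because axline draws a line that is straight on screen, the line through (a, a) and (b, b) on log-log axes is the identity y = x, and its two points end up inside the autoscaled limits.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for an interactive window
import matplotlib.pyplot as plt

# Illustrative stand-ins for the data minimum and maximum
xy_min, xy_max = 0.02, 0.98

fig, ax = plt.subplots()
ax.set_xscale('log')
ax.set_yscale('log')

# axline (matplotlib >= 3.3) draws an infinite line that is straight on screen,
# so on log-log axes the line through (a, a) and (b, b) is y = x.
ax.axline((xy_min, xy_min), (xy_max, xy_max), color='crimson', linestyle='--')

# The two points were added to the data limits, so autoscaling covers them.
xlim, ylim = ax.get_xlim(), ax.get_ylim()
```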

    To access an individual subplot, index the 2D array g.axes as g.axes[row, col]; each entry is a matplotlib Axes, just like g.ax_joint in a JointGrid. To loop over all subplots, iterate over g.axes.flat.
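A sketch of that direct-access route, covering the three points of the question without a helper function: it sets log scales, identical x and y limits and major/minor gridlines on every subplot through g.axes, then draws the identity line on the off-diagonal subplots. The padding factors 0.8 and 1.2 and the Agg backend are assumptions chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove it and call plt.show() for a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)})
g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)

# One range for every subplot; the 0.8 / 1.2 padding factors are arbitrary.
lo, hi = df.min().min() * 0.8, df.max().max() * 1.2

# g.axes is a 2D NumPy array of matplotlib Axes; ndenumerate yields ((row, col), ax).
for (row, col), ax in np.ndenumerate(g.axes):
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(lo, hi)  # same limits on x and y, so ylims = xlims everywhere
    ax.set_ylim(lo, hi)
    ax.grid(which='major', color='navy', lw=1, ls=':')
    ax.grid(which='minor', color='navy', lw=0.2, ls=':')
    if row != col:
        # Identity line y = x on the off-diagonal subplots (matplotlib >= 3.3)
        ax.axline((lo, lo), (hi, hi), color='crimson', ls='--', lw=2)
```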

    You can also use the g.map_...(given_function) functions (map_offdiag, map_lower, map_upper, ...). They call given_function once per subplot, passing the data column used for x as the first argument and the column used for y as the second. Extra keyword arguments given via g.map_...(given_function, param1=..., ...) are forwarded to every call, so they either bind to named parameters or end up in the kwargs dict. Before each call the current Axes is set to the subplot being drawn (the Axes is not passed as an extra parameter), so you can plot on it directly with plt.plot, or fetch it with ax = plt.gca().
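A small sketch of that calling convention. The function show_args and the keyword note are hypothetical names made up for illustration, and the Agg backend is an assumption; the function only records what PairGrid hands it and marks the points on the current subplot.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

calls = []  # hypothetical log of what PairGrid passes in

def show_args(xdata, ydata, note=None, **kwargs):
    # xdata / ydata: the x and y columns for this subplot (pandas Series).
    # note: our own extra keyword, forwarded from map_offdiag.
    # kwargs: anything else, including keywords seaborn adds itself.
    ax = plt.gca()  # the subplot PairGrid made current before calling us
    calls.append((xdata.name, ydata.name, note, ax))
    ax.plot(xdata, ydata, 'x', color='gray')  # draw directly on that subplot

np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)})
g = sns.PairGrid(df)
g.map_offdiag(show_args, note='hello')
```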

    Here is some example code that tackles your questions. Because the identity line is drawn through the same two points (the overall minimum and maximum of the data) on every subplot, the x and y limits end up equal. Note that by default the limits are "shared": x along each column and y along each row, with tick labels only on the leftmost and bottom subplots.

    import seaborn as sns
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
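    def update_plot(xdata, ydata, xy_min, xy_max, **kwargs):
        # Called once per off-diagonal subplot, which PairGrid has already made current.
        # xdata/ydata are that subplot's columns; any other extra keywords land in kwargs.
        plt.yscale('log')
        plt.xscale('log')
        # Pre-3.3 alternative: only a segment from (xy_min, xy_min) to (xy_max, xy_max)
        # plt.plot([xy_min, xy_max], [xy_min, xy_max], color='crimson', linestyle='--', linewidth=2)
        # Identity line y = x spanning the whole subplot (needs matplotlib >= 3.3)
        plt.axline([xy_min, xy_min], [xy_max, xy_max], color='crimson', linestyle='--', linewidth=2)
        # Separate styles for major and minor gridlines
        plt.grid(which='major', color='navy', lw=1, ls=':')
        plt.grid(which='minor', color='navy', lw=0.2, ls=':')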
    
    np.random.seed(0)
    df = pd.DataFrame({'x': np.random.rand(10),
                       'y': np.random.rand(10),
                       'z': np.random.rand(10)})
    g = sns.PairGrid(df)
    g.map_offdiag(sns.scatterplot)
    g.map_offdiag(update_plot, xy_min=df.min().min(), xy_max=df.max().max())
    plt.subplots_adjust(left=0.1)  # a bit more room at the left for the labels
    plt.show()
    

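The remark that limits are shared by default can be checked with a short sketch: PairGrid creates its subplots with x shared per column and y shared per row, so limits set on one subplot propagate along its column (x) and its row (y). The specific limit values and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)})
g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)

# Set limits on a single subplot (row 0, column 1) only.
g.axes[0, 1].set_xlim(0.01, 2)
g.axes[0, 1].set_ylim(0.05, 3)

# x is shared down column 1 and y along row 0, so these follow automatically,
col1_xlims = [g.axes[r, 1].get_xlim() for r in range(3)]
row0_ylims = [g.axes[0, c].get_ylim() for c in range(3)]
# while a subplot in another column keeps its own autoscaled x limits.
other_xlim = g.axes[2, 0].get_xlim()
```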