
Plotting multiple different plots in one figure using Seaborn


I am attempting to recreate the following plot from the book Introduction to Statistical Learning using seaborn:

[Figure: target plot from the book]

Specifically, I want to use seaborn's lmplot to create the first two plots and boxplot to create the third. The main problem is that lmplot creates its own FacetGrid (according to this answer), which forces me to hackily add another matplotlib Axes for the boxplot. Is there an easier way to achieve this? Below, I have to do quite a bit of manual manipulation to get the desired plot.

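import matplotlib.pyplot as plt
import seaborn as sns

# lmplot is figure-level: it builds its own FacetGrid with one column per variable
seaborn_grid = sns.lmplot(x='value', y='wage', col='variable', hue='education',
                          data=df_melt, facet_kws=dict(sharex=False))
seaborn_grid.fig.set_figwidth(8)

# Read the positions (figure coordinates) of the two facet Axes via the public API
left, bottom, width, height = seaborn_grid.fig.axes[0].get_position().bounds
left2, bottom2, width2, height2 = seaborn_grid.fig.axes[1].get_position().bounds
left_diff = left2 - left
# Manually add a third Axes to the right, with the same size and spacing
seaborn_grid.fig.add_axes((left2 + left_diff, bottom, width, height))

sns.boxplot(x='education', y='wage', data=df_wage, ax=seaborn_grid.fig.axes[2])
ax2 = seaborn_grid.fig.axes[2]
ax2.set_yticklabels([])
ax2.set_xticklabels(ax2.get_xmajorticklabels(), rotation=30)
ax2.set_ylabel('')
ax2.set_xlabel('')

# Move the FacetGrid's figure legend so it does not overlap the new Axes
leg = seaborn_grid.fig.legends[0]
leg.set_bbox_to_anchor([0, .1, 1.5, 1])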

Which yields:

[Figure: resulting seaborn plot]

Sample data for the DataFrames (each dict is wrapped in pd.DataFrame so seaborn can use it):

import pandas as pd

df_melt = pd.DataFrame({
    'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
    'value': [18, 24, 45, 43, 50],
    'variable': ['age', 'age', 'age', 'age', 'age'],
    'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515]})

df_wage = pd.DataFrame({
    'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
    'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515]})
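The sample df_melt is in the long format that pandas.melt produces, while df_wage is a plain per-person table. As a hedged sketch, assuming a hypothetical wide DataFrame `df_wide` with one column per predictor (only `age` is available in the sample), both frames could be derived from a single source like this:

```python
import pandas as pd

# Hypothetical wide-format table: one row per person, one column per predictor.
# Values are copied from the sample data; only the 'age' predictor exists there.
df_wide = pd.DataFrame({
    'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
    'age': [18, 24, 45, 43, 50],
    'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515],
})

# Stack the predictor columns into ('variable', 'value') pairs, keeping the
# response (wage) and the grouping column (education) as identifiers.
df_melt = pd.melt(df_wide, id_vars=['education', 'wage'], value_vars=['age'])

# The boxplot only needs education and wage, one row per person.
df_wage = df_wide[['education', 'wage']]

print(df_melt)
```

If `df_wide` had another numeric predictor column, adding it to `value_vars` would append its rows to `df_melt`, and `lmplot(col='variable')` would draw it in its own facet.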

Solution

  • One possibility would be to NOT use lmplot(), but to use regplot() directly instead. Unlike lmplot(), regplot() draws on whichever Axes you pass to its ax= argument.

    You lose the ability to automatically split your dataset according to a certain variable (regplot() also has no hue= parameter), but if you know beforehand which plots you want to generate, it shouldn't be a problem.

    Something like this:

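    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Create the three Axes up front; each seaborn call draws on the one passed via ax=
    fig, axs = plt.subplots(ncols=3)
    # In practice, pass a different subset to each regplot,
    # e.g. df_melt[df_melt['variable'] == 'age']; here both panels use the same data
    sns.regplot(x='value', y='wage', data=df_melt, ax=axs[0])
    sns.regplot(x='value', y='wage', data=df_melt, ax=axs[1])
    # df_wage has one row per person, so no observation is counted twice
    sns.boxplot(x='education', y='wage', data=df_wage, ax=axs[2])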
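Building on the regplot() approach, here is a hedged sketch of the whole figure. Because regplot() has no hue= argument, it emulates lmplot's hue='education' by calling regplot() once per education level with its own color and label. It assumes the sample data from the question (so only one regplot panel appears; with the full data, df_melt would hold several variables and each would get its own panel), and the output filename wage_panels.png is hypothetical.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the script runs anywhere; not needed in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Sample data from the question (only the 'age' variable is present).
df_melt = pd.DataFrame({
    'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
    'value': [18, 24, 45, 43, 50],
    'variable': ['age'] * 5,
    'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515],
})
df_wage = df_melt[['education', 'wage']]

variables = df_melt['variable'].unique()        # one regplot panel per predictor
levels = sorted(df_melt['education'].unique())  # emulates lmplot's hue='education'
colors = dict(zip(levels, sns.color_palette(n_colors=len(levels))))

# One extra column for the boxplot; sharey keeps every panel on the same wage scale.
n_panels = len(variables) + 1
fig, axs = plt.subplots(ncols=n_panels, sharey=True, figsize=(4 * n_panels, 4))

for ax, var in zip(axs, variables):
    panel = df_melt[df_melt['variable'] == var]
    # regplot has no hue parameter, so draw one scatter (+ fit) per education level.
    for level in levels:
        subset = panel[panel['education'] == level]
        if subset.empty:
            continue
        sns.regplot(x='value', y='wage', data=subset, ax=ax,
                    color=colors[level], label=level,
                    fit_reg=len(subset) > 1,  # skip the line when there is a single point
                    ci=None)
    ax.set_xlabel(var)

axs[0].legend(title='education', fontsize='small')

# The last Axes holds the boxplot; its category order matches the legend.
sns.boxplot(x='education', y='wage', data=df_wage, order=levels, ax=axs[-1])
axs[-1].tick_params(axis='x', labelrotation=30)
axs[-1].set_xlabel('')
axs[-1].set_ylabel('')

fig.tight_layout()
fig.savefig('wage_panels.png')  # hypothetical filename; use plt.show() interactively
```

Here sharey=True takes the place of the set_yticklabels([]) hack from the question: matplotlib hides the inner y tick labels itself, and no Axes positions have to be computed by hand.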