Tags: matplotlib, seaborn, jointplot

Seaborn jointplot link x-axis to Matplotlib subplots


Is there a way to add extra subplots created with vanilla Matplotlib below a Seaborn jointplot, sharing the x-axis? Ideally I'd like to control the ratio between the jointplot and the additional plots (similar to gridspec_kw={'height_ratios': [3, 1, 1]}).
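For reference, this is what `gridspec_kw={'height_ratios': [3, 1, 1]}` gives when every panel is an ordinary Matplotlib Axes. The difficulty in the question stems from `sns.jointplot` being a figure-level function that creates its own figure and grid. A minimal sketch, assuming synthetic data in place of `diamonds` and no marginal plots:

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic data (assumption): stands in for diamonds["carat"], "price", "depth", "table"
rng = np.random.default_rng(0)
carat = rng.uniform(1, 2, 200)
price = 5000 * carat + rng.normal(0, 500, 200)
depth = rng.normal(62, 1.5, 200)
table = rng.normal(57, 2, 200)

# Three stacked rows sharing x; the top row is three times as tall as each lower row
fig, (ax_top, ax_mid, ax_bot) = plt.subplots(
    3, 1, sharex=True, figsize=(6, 8),
    gridspec_kw={"height_ratios": [3, 1, 1]},
)
ax_top.scatter(carat, price, s=10)
ax_mid.scatter(carat, depth, color="gray", edgecolor="black")
ax_bot.scatter(carat, table, color="gray", edgecolor="black")

ax_top.set_xlim(1, 2)  # sharex=True propagates this limit to the lower rows
for ax, label in [(ax_top, "price"), (ax_mid, "depth"), (ax_bot, "table")]:
    ax.set_ylabel(label)
ax_bot.set_xlabel("carat")
```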

I tried to fake it by tuning figsize in the Matplotlib subplots, but obviously it doesn't work well when the KDE curves in the marginal plot change. While I could manually resize the output PNG to shrink/grow one of the figures, I'd like to have everything aligned automatically.

I know this is tricky with the way the joint grid is set up, but maybe it is reasonably simple for someone fluent in the underpinnings of Seaborn.

Here is a minimal working example; it produces two separate figures:

import matplotlib.pyplot as plt
import seaborn as sns
# Jupyter-only magic; omit it when running as a plain Python script
%matplotlib inline

Figure 1

diamonds = sns.load_dataset('diamonds')
g = sns.jointplot(
    data=diamonds,
    x="carat",
    y="price",
    hue="cut",
    xlim=(1, 2),
)
g.ax_marg_x.remove()

Figure 2

fig, (ax1, ax2) = plt.subplots(2,1,sharex=True)
ax1.scatter(x=diamonds["carat"], y=diamonds["depth"], color="gray", edgecolor="black")
ax1.set_xlim([1, 2])
ax1.set_ylabel("depth")
ax2.scatter(x=diamonds["carat"], y=diamonds["table"], color="gray", edgecolor="black")
ax2.set_xlabel("carat")
ax2.set_ylabel("table")

Desired output (image): the jointplot on top with the two scatter plots stacked below it, all sharing the x-axis.


Solution

  • You could take the figure created by jointplot(), shrink the joint grid with subplots_adjust() to free space at the bottom, and add two extra axes there that share the x-axis with the joint plot.

    The offsets and sizes in the example will need some tweaking for each particular figure.

    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    import seaborn as sns
    
    diamonds = sns.load_dataset('diamonds')
    g = sns.jointplot(data=diamonds, x="carat", y="price", hue="cut",
                      xlim=(1, 2), height=12)
    # Drop the top marginal plot; removing an Axes also detaches it from its shared-x group
    g.ax_marg_x.remove()
    # Squeeze the joint grid into the upper part of the figure to free space below;
    # top slightly above 1 lets ax_joint grow into the row the removed marginal occupied
    g.figure.subplots_adjust(left=0.08, right=0.97, top=1.05, bottom=0.45)
    
    # Two new axes below ax_joint. bbox_to_anchor=(x0, y0, width, height) is given in
    # ax_joint's axes coordinates; each inset fills the full width and 30% of that height
    axins1 = inset_axes(g.ax_joint, width="100%", height="30%",
                        bbox_to_anchor=(0, -0.4, 1, 1),
                        bbox_transform=g.ax_joint.transAxes, loc="lower left", borderpad=0)
    axins2 = inset_axes(g.ax_joint, width="100%", height="30%",
                        bbox_to_anchor=(0, -0.75, 1, 1),
                        bbox_transform=g.ax_joint.transAxes, loc="lower left", borderpad=0)
    # Share x with the joint axes after creation. Axes.sharex() replaces the older
    # get_shared_x_axes().join()/.remove() calls, deprecated in Matplotlib 3.6 and later removed
    axins1.sharex(g.ax_joint)
    axins2.sharex(g.ax_joint)
    
    axins1.scatter(x=diamonds["carat"], y=diamonds["depth"], color="grey", edgecolor="black")
    axins1.set_ylabel("depth")
    axins2.scatter(x=diamonds["carat"], y=diamonds["table"], color="grey", edgecolor="black")
    axins2.set_xlabel("carat")
    axins2.set_ylabel("table")
    g.ax_joint.set_xlim(1, 2)
    # Hide the redundant x tick labels of the middle row
    axins1.tick_params(labelbottom=False)
    plt.show()
    

    [Figure: sns.jointplot with extra subplots]

    PS: "How to share x axes of two subplots after they have been created" has more on sharing axes after creation (although here you could get the same visual result simply by setting the same xlim on each subplot).

    The code that positions the new axes was adapted from a tutorial example.