Search code examples
Tags: python, python-3.x, matplotlib, plot, seaborn

How to embed inset plots generated with seaborn lmplot?


In seaborn, `sns.lmplot` returns a `FacetGrid` object. I would like to add an inset plot to it. Here is a self-contained "working" example:

# Question setup (Jupyter cell): parse the toy data set and smooth it per group.
from io import StringIO
import pandas as pd
# IPython/Jupyter magic: only valid inside a notebook, not in a plain .py script.
%matplotlib inline
# Tab-separated toy data with header time, sex, age, val1, val2; records separated by "\n".
df_string='time\tsex\tage\tval1\tval2\n1\tM\t18\t0.285837375\t4.402793733\n2\tM\t18\t0.234239365\t2.987464305\n3\tM\t18\t0.820418465\t3.23991295\n4\tM\t18\t0.826027695\t9.707366329\n5\tM\t18\t0.625449525\t2.971235344\n6\tM\t18\t0.485980081\t5.517575471\n7\tM\t18\t0.136163546\t3.620177216\n8\tM\t18\t0.784944053\t5.116294718\n9\tM\t18\t0.981526403\t6.348155198\n10\tM\t18\t0.822237037\t4.682176522\n1\tF\t22\t0.104339381\t5.434133736\n2\tF\t22\t0.788797127\t0.843869877\n3\tF\t22\t0.997986894\t8.765048753\n4\tF\t22\t0.51167857\t2.054679646\n5\tF\t22\t0.328416139\t6.581617426\n6\tF\t22\t0.317804112\t1.584234393\n7\tF\t22\t0.489944956\t8.564257177\n8\tF\t22\t0.207348127\t1.346020575\n9\tF\t22\t0.727347344\t7.487993859\n10\tF\t22\t0.252917798\t8.822904862\n11\tF\t22\t0.690106636\t6.728470474\n12\tF\t22\t0.508078197\t2.489437246\n'
df = pd.read_csv(StringIO(df_string), sep='\t')


# running a moving average
# Centered 3-row rolling mean, computed separately within each (sex, age) group;
# min_periods=1 lets the first/last rows of each group average over fewer rows
# instead of becoming NaN.
# NOTE(review): the numeric `time` column is averaged as well, so the x-values get
# smoothed too (e.g. the first M row's time becomes (1 + 2) / 2 = 1.5) -- confirm intended.
df_tmp = df.groupby(['sex', 'age']).rolling(min_periods=1, window=3, center=True).mean()

# Quick line plot of the smoothed frame.
df_tmp.plot()

[Figure: output of `df_tmp.plot()`]

# Question's attempt: draw an lmplot, then try to draw a second lmplot into an inset.
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt  # pyplot rather than the discouraged matplotlib.pylab
# mpl_toolkits.axes_grid was deprecated and removed in Matplotlib 3.6; the inset
# helpers live in mpl_toolkits.axes_grid1.
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

# df_tmp is the grouped rolling-mean frame from the previous cell; reset_index moves
# the groupby keys (sex, age) out of the index so "sex" can be used as the hue column.
df_to_plot = df_tmp.reset_index()
g = sns.lmplot(x='time', y='val1', hue="sex", x_estimator=np.mean, height=10, aspect=1,
               data=df_to_plot, logx=True, legend_out=True, truncate=True)

g.axes[0][0].xaxis.set_label_text('t [sec]')
g.set(yscale="log")

ax = g.axes[0][0]
# Empty inset axes occupying 30% x 40% of the main axes.
axins = inset_axes(ax, "30%", "40%")
# NOTE: this does NOT draw into axins. lmplot builds a new FacetGrid, i.e. a new
# figure, which is why a second, separate plot appears instead of the inset.
g_inset = sns.lmplot(x='time', y='val1', hue="sex", x_estimator=np.mean, data=df_to_plot, legend_out=False)

But instead of the second plot being embedded in the inset, I get two separate plots:

[Figure 1: the main lmplot figure]

[Figure 2: the second lmplot, drawn as a separate figure instead of inside the inset]

Ultimately, I would like code that finds an area of empty white space and places the inset cleanly inside it, similar to this example by Christian:

[Figure: Christian's example, with an inset placed in an empty region of the main plot]


FWIW, my toy data set looks like this in tabular form:

[Figure: the toy data set displayed as a table]


Solution

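  • Since `lmplot` returns a `FacetGrid`, which always creates its own figure, `lmplot` cannot draw into an existing axes such as an inset. Instead, call `sns.regplot` once per hue level (here, once per `sex` group) and pass the target axes explicitly. Note that the code below places the inset at a fixed position (`loc="lower right"`); it does not search for empty space automatically.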

    # Answer: emulate lmplot(hue="sex") with one regplot per group, drawn both
    # into the main axes and into an inset axes.
    from io import StringIO

    import matplotlib.pyplot as plt  # pyplot rather than the discouraged matplotlib.pylab
    import numpy as np
    import pandas as pd
    import seaborn as sns
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    # Toy data, one tab-separated record per line. The backslash after the
    # opening quotes drops the leading newline.
    df_string = """\
    time\tsex\tage\tval1\tval2
    1\tM\t18\t0.285837375\t4.402793733
    2\tM\t18\t0.234239365\t2.987464305
    3\tM\t18\t0.820418465\t3.23991295
    4\tM\t18\t0.826027695\t9.707366329
    5\tM\t18\t0.625449525\t2.971235344
    6\tM\t18\t0.485980081\t5.517575471
    7\tM\t18\t0.136163546\t3.620177216
    8\tM\t18\t0.784944053\t5.116294718
    9\tM\t18\t0.981526403\t6.348155198
    10\tM\t18\t0.822237037\t4.682176522
    1\tF\t22\t0.104339381\t5.434133736
    2\tF\t22\t0.788797127\t0.843869877
    3\tF\t22\t0.997986894\t8.765048753
    4\tF\t22\t0.51167857\t2.054679646
    5\tF\t22\t0.328416139\t6.581617426
    6\tF\t22\t0.317804112\t1.584234393
    7\tF\t22\t0.489944956\t8.564257177
    8\tF\t22\t0.207348127\t1.346020575
    9\tF\t22\t0.727347344\t7.487993859
    10\tF\t22\t0.252917798\t8.822904862
    11\tF\t22\t0.690106636\t6.728470474
    12\tF\t22\t0.508078197\t2.489437246
    """
    df = pd.read_csv(StringIO(df_string), sep="\t")

    fig, ax = plt.subplots()

    # lmplot builds a FacetGrid, which always creates its own figure, so it cannot
    # draw into an existing axes. Draw one regplot per hue level instead. Group
    # once and reuse it so both loops iterate the groups in the same order, which
    # keeps each group's colour identical in the main plot and in the inset (each
    # axes cycles through the same default colours).
    groups = list(df.groupby("sex"))

    for sex, grp in groups:
        sns.regplot(x="time", y="val1", x_estimator=np.mean, data=grp,
                    logx=True, truncate=True, label=sex, ax=ax)

    ax.xaxis.set_label_text("t [sec]")
    ax.set(yscale="log")
    # Stand-in for lmplot's hue legend; kept away from the lower-right inset.
    ax.legend(title="sex", loc="upper left")

    # Inset spanning 30% x 40% of the main axes, anchored at the lower right.
    axins = inset_axes(ax, "30%", "40%", loc="lower right", borderpad=3)

    for sex, grp in groups:
        sns.regplot(x="time", y="val1", x_estimator=np.mean, data=grp,
                    truncate=True, ax=axins)

    plt.show()
    

    [Figure: resulting plot, with the regplot inset in the lower-right corner]