python pandas matplotlib seaborn linear-regression

How to plot multiple linear regressions in the same figure


Given the following:

import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(365)
x1 = np.random.randn(50)
y1 = np.random.randn(50) * 100
x2 = np.random.randn(50)
y2 = np.random.randn(50) * 100

df1 = pd.DataFrame({'x1':x1, 'y1': y1})
df2 = pd.DataFrame({'x2':x2, 'y2': y2})

# Recent seaborn versions require x, y and data to be passed as keywords
sns.lmplot(x='x1', y='y1', data=df1, fit_reg=True, ci=None)
sns.lmplot(x='x2', y='y2', data=df2, fit_reg=True, ci=None)

This will create 2 separate plots. How can I add the data from df2 onto the SAME graph? All the seaborn examples I have found online seem to focus on how you can create adjacent graphs (say, via the 'hue' and 'col_wrap' options). Also, I prefer not to use the dataset examples where an additional column might be present as this does not have a natural meaning in the project I am working on.

If there is a mixture of matplotlib/seaborn functions that are required to achieve this, I would be grateful if someone could help illustrate.


Solution

  • You could use seaborn's FacetGrid class to get the desired result. You would need to replace your plotting calls with these lines:

    import matplotlib.pyplot as plt  # needed for plt.scatter below

    # sns.lmplot('x1', 'y1', df1, fit_reg=True, ci = None)
    # sns.lmplot('x2', 'y2', df2, fit_reg=True, ci = None)
    # Stack both frames into one long frame; the 'df' column records the source
    df = pd.concat([df1.rename(columns={'x1':'x','y1':'y'}).assign(df='df1'),
                    df2.rename(columns={'x2':'x','y2':'y'}).assign(df='df2')],
                   ignore_index=True)

    pal = dict(df1="red", df2="blue")
    # 'size' was renamed to 'height' in seaborn 0.9
    g = sns.FacetGrid(df, hue='df', palette=pal, height=5);
    g.map(plt.scatter, "x", "y", s=50, alpha=.7, linewidth=.5, edgecolor="white")
    # robust=True needs statsmodels installed
    g.map(sns.regplot, "x", "y", ci=None, robust=True)
    g.add_legend();
    

    This will yield a single plot with the points and regression lines of both data sets, df1 in red and df2 in blue:

    [plot image]

    If I understand correctly, this is what you need.

    Note that you will need to pay attention to the regplot parameters; the values I used are only examples and you may want to change them.

    • The ; at the end of a line suppresses the output of the command (I use an IPython notebook, where it would otherwise be visible).
    • The docs give some explanation of the .map() method. In essence, it does just that: it maps a plotting command onto the data. However, it only works with 'low-level' plotting commands like regplot, and not with lmplot, which itself calls regplot behind the scenes.
    • Normally plt.scatter would take the parameters c='none', edgecolor='r' to make non-filled markers. But seaborn interferes with this and forces a color onto the markers, so I don't see an easy/straightforward way to fix it other than manipulating the ax elements after seaborn has produced the plot, which is best addressed as part of a different question.
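
Since regplot is the 'low-level' command mentioned above, it can also be called directly on a shared matplotlib Axes. The following is only a sketch of that alternative, not the FacetGrid approach from this answer: it assumes the same df1 and df2 as in the question, skips the extra 'df' column entirely, and the colors, labels and figure size are arbitrary choices.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Same kind of random data as in the question
np.random.seed(365)
df1 = pd.DataFrame({'x1': np.random.randn(50), 'y1': np.random.randn(50) * 100})
df2 = pd.DataFrame({'x2': np.random.randn(50), 'y2': np.random.randn(50) * 100})

# One figure, one Axes; both regplot calls draw onto it via ax=
fig, ax = plt.subplots(figsize=(6, 5))
sns.regplot(x='x1', y='y1', data=df1, ci=None, color='red', label='df1', ax=ax)
sns.regplot(x='x2', y='y2', data=df2, ci=None, color='blue', label='df2', ax=ax)

# Each call sets its own axis labels, so set neutral ones at the end
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
```

Call plt.show() (or fig.savefig with a file name of your choice) afterwards to see the result. This works because regplot accepts an ax argument, whereas lmplot always creates its own figure.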