Add regression line equation to facet_wrap in plotnine


I'm trying to add the equation of the linear fit to the plotted data. I have used geom_smooth with method='lm' and formula='y~x'. The linked answer "add eq to single plot" shows how to do this for a single plot; my question is how to add an equation to each panel in plotnine when using facet_wrap or facet_grid.

A working example is given below:

import plotnine as p9
from scipy import stats
from plotnine.data import mtcars as df

# create plot
plot=(p9.ggplot(data=df, mapping= p9.aes('wt','mpg', color = 'factor(gear)'))
    + p9.geom_point(p9.aes())
    + p9.facet_wrap('~ gear')
    + p9.xlab('Wt')+ p9.ylab(r'MPG')
    + p9.geom_smooth(method='lm', formula = 'y~x', se=False)
     )
print(plot)

The solution given for a single plot in the answer linked above is:

import plotnine as p9
from scipy import stats
from plotnine.data import mtcars as df

#calculate best fit line
slope, intercept, r_value, p_value, std_err = stats.linregress(df['wt'],df['mpg'])
df['fit']=df.wt*slope+intercept
#format text 
txt= 'y = {:4.2e} x + {:4.2E};   R^2= {:2.2f}'.format(slope, intercept, r_value*r_value)
#create plot. The 'factor' is a nice trick to force a discrete color scale
plot=(p9.ggplot(data=df, mapping= p9.aes('wt','mpg', color = 'factor(gear)'))
    + p9.geom_point(p9.aes())
    + p9.xlab('Wt')+ p9.ylab(r'MPG')
    + p9.geom_line(p9.aes(x='wt', y='fit'), color='black')
    + p9.annotate('text', x= 3, y = 35, label = txt))
#for some reason, I have to print my plot 
print(plot)

The same using ggplot in R is discussed here: add eq to facet in R

I am not sure how this can be achieved in plotnine.


Solution

  • The ggplot2 way of achieving this is to build a second data frame with one row per facet. One column holds the categories of the faceting variable (it must use the same column name as in the main data, here gear), and a second column holds the label you want to add to each facet. Pass this data frame to geom_text through its data argument, and each label is drawn only in the facet whose category matches.

    import plotnine as p9
    import pandas as pd
    from scipy import stats
    from plotnine.data import mtcars as df
    
    def model(df):
        slope, intercept, r_value, p_value, std_err = stats.linregress(df['wt'],df['mpg'])
    
        return pd.Series({'slope' : slope, 'intercept' : intercept, 'r2_value' : r_value**2, 'p_value' : p_value, 'std_err' : std_err })
    
    # One row of fit statistics per gear; reset_index() turns the group key
    # back into a 'gear' column, so no rename is needed
    df_labels = df.groupby('gear').apply(model).reset_index()
    
    string = 'y = {:4.2e} x + {:4.2E};   R^2= {:2.2f}'
    df_labels['txt'] = [string.format(*r) for r in df_labels[['slope', 'intercept', 'r2_value']].values.tolist()]
    
    # create plot
    plot=(p9.ggplot(data=df, mapping= p9.aes('wt','mpg', color = 'factor(gear)'))
        + p9.geom_point(p9.aes())
        + p9.geom_text(p9.aes(label = 'txt'), data = df_labels, x = 5.5, y = 35, color = 'black', va = 'top', ha = 'right')
        + p9.facet_wrap('~ gear', ncol = 1)
        + p9.xlab('Wt')+ p9.ylab(r'MPG')
        + p9.geom_smooth(method='lm', formula = 'y~x', se=False)
         )
    print(plot)
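
The label table is the key piece of the approach above, so it may help to look at it on its own. The following is a minimal sketch, not part of the original answer: it builds the same kind of table with an explicit loop over groups instead of groupby().apply(). The helper name facet_labels and the toy data frame are made up for illustration (the toy values stand in for mtcars so the snippet runs without plotnine), and the label uses a plain two-decimal format rather than the scientific notation above.

```python
import pandas as pd
from scipy import stats

# Toy data standing in for mtcars; the values are invented for illustration
toy = pd.DataFrame({
    'gear': [3, 3, 3, 4, 4, 4, 4],
    'wt':   [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0],
    'mpg':  [30.0, 25.0, 20.0, 10.0, 12.0, 14.0, 16.0],
})

def facet_labels(data, facet, x, y):
    """Return one row per facet level with a formatted regression equation.

    Hypothetical helper: the facet column keeps its original name so that
    plotnine can match each label to its panel.
    """
    rows = []
    for level, group in data.groupby(facet):
        fit = stats.linregress(group[x], group[y])
        rows.append({
            facet: level,
            'txt': 'y = {:.2f} x + {:.2f};  R^2 = {:.2f}'.format(
                fit.slope, fit.intercept, fit.rvalue ** 2),
        })
    return pd.DataFrame(rows)

labels = facet_labels(toy, 'gear', 'wt', 'mpg')
print(labels)
```

The resulting data frame has a gear column and a txt column, so it should be usable in place of df_labels in the geom_text call above (data = labels, label mapped to 'txt').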
    
