Search code examples
python-3.xnumpyregressionseabornpolynomials

How to plot a polynomial model of multiple categories on a scatter plot


I am working with a standard diamonds dataset and I need to create a graph of the following type:

enter image description here

All I've got at the moment is 1)

import seaborn as sns
import matplotlib.pyplot as plt

# load the data
df = sns.load_dataset('diamonds')

plt.figure(figsize=(12, 8), dpi=200)

# draw the points first and keep the Axes that seaborn returns
ax = sns.scatterplot(data=df, x='carat', y='price', hue='cut', palette='viridis')

# overlay one line per cut on that same Axes
sns.lineplot(data=df, x='carat', y='price', hue='cut', palette='viridis', ax=ax)

# axis labels and title
ax.set(
    xlabel='Carat',
    ylabel='Price',
    title='Scatter Plot of Price vs. Carat with Curved Lines (Viridis Palette)',
)

# place the legend outside the plotting area, to the right
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

(https://i.sstatic.net/oyMw0.png)

2)

plt.figure(figsize=(12, 8), dpi=200)

# one regplot call per cut; each call draws onto the current Axes
for cut_name in df['cut'].unique():
    subset = df[df['cut'] == cut_name]
    sns.regplot(data=subset, x='carat', y='price', label=cut_name, scatter_kws={'s': 10})

plt.xlabel('Carat')
plt.ylabel('Price')
plt.title('Regression Plot of Price vs. Carat by Cut')

# the legend entries come from the label= passed to each regplot call
plt.legend(title='Cut')

plt.show()

enter image description here

How can I get the graph with a polynomial fit?


Solution

  • Data and Imports

    import seaborn as sns
    import numpy as np
    import matplotlib.pyplot as plt
    
    # load data
    df = sns.load_dataset('diamonds')
    
       carat      cut color clarity  depth  table  price     x     y     z
    0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
    1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
    2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
    3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
    4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75
    

    np.polyfit and np.poly1d

    # create figure and Axes
    fig, ax = plt.subplots(figsize=(12, 8))

    # Fix the order of the cut levels once and reuse it for both the scatter
    # points and the fitted lines. unique() yields values in order of
    # appearance, which may differ from the order seaborn uses for hue levels,
    # so relying on it can pair a line with another cut's color.
    cuts = df['cut'].drop_duplicates().sort_values().tolist()

    # one viridis color per cut, shared by the points and the fitted lines
    colors = sns.color_palette('viridis', n_colors=len(cuts))
    palette = dict(zip(cuts, colors))

    # plot the scatter points
    sns.scatterplot(data=df, x='carat', y='price', hue='cut', hue_order=cuts, palette=palette, s=10, alpha=0.4, ec='none', ax=ax)

    # iterate through each cut and its matching color
    for cut, color in palette.items():

        # select the data for a given cut
        data = df[df['cut'].eq(cut)]

        # create the 5th-degree polynomial model
        p = np.poly1d(np.polyfit(data['carat'], data['price'], 5))

        # create x values spanning only this cut's observed carat range
        xp = np.linspace(data['carat'].min(), data['carat'].max(), 1000)

        # plot the model as a dotted line in the cut's color
        sns.lineplot(x=xp, y=p(xp), color=color, ax=ax, ls=':')

    # move the legend outside the Axes, to the right
    sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
    

    enter image description here

    sns.lmplot

    • If order is greater than 1, uses numpy.polyfit to estimate a polynomial regression.
    • Separate categories with the hue parameter.
    # plot the polynomial model: one 5th-degree fit per cut, separated by hue
    g = sns.lmplot(data=df, x='carat', y='price', hue='cut', palette='viridis', order=5, truncate=True, ci=None, scatter_kws={'s': 10, 'alpha': 1}, height=8, aspect=1.25)

    # access the axes to add the manual poly model to
    ax = g.axes.flat[0]

    # plot the manual np.polyfit model for comparison; every line is drawn in
    # black, so no per-cut color list is needed here
    for cut in df['cut'].unique():
        data = df[df['cut'].eq(cut)]
        p = np.poly1d(np.polyfit(data['carat'], data['price'], 5))
        xp = np.linspace(data['carat'].min(), data['carat'].max(), 1000)
        sns.lineplot(x=xp, y=p(xp), color='k', ax=ax, ls=':', legend=False)
    

    enter image description here

    sns.regplot

    • You must specify order= and set ci=None.
    • Unlike lmplot, there isn't a hue parameter.
    fig, ax = plt.subplots(figsize=(12, 8))

    # regplot has no hue parameter, so draw one 5th-order fit per cut on the shared Axes
    for cut_name in df['cut'].unique():
        subset = df.loc[df['cut'] == cut_name]
        sns.regplot(data=subset, x='carat', y='price', order=5, ci=None, label=cut_name, scatter_kws={'s': 10}, ax=ax)
    

    enter image description here