python matplotlib math graphing

Make quadratic regression lines connect seamlessly with matplotlib


I have a basic graph example and I am trying to make all the points lie on some sort of curved line. I have an idea of how to go about this, but I am not sure how to implement it or whether it is even possible. Below is the code I used, followed by a picture of the graph it produces:

import matplotlib.pyplot as plt
import numpy as np

# original data
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y = [2, 7, 3, 4, 5, 1, 6, 9, 4, 6]

# quadratic regression
for i in range(len(x) - 2):  # one window per run of three consecutive points
    sub_x = x[i:i+3]
    sub_y = y[i:i+3]
    
    model = np.poly1d(np.polyfit(sub_x, sub_y, 2))
    polyline = np.linspace(min(sub_x), max(sub_x), 200)
    plt.plot(polyline, model(polyline), color="#6D34D6", linestyle='dashed')

# plot lines
plt.scatter(x, y, color='#FF3FAF')
plt.plot(x, y, color='#FF3FAF', linestyle='solid')
plt.show()

Here is the graph that is produced:

output

My question is: how do I make all the dashed lines connect seamlessly? One idea was to average each pair of overlapping curve segments that share the same points, but I don't know how to go about doing so. Another idea was to build some sort of Bézier curve that connects all the points, but that sounds unnecessarily complicated.

Something like the green line should be the output (sorry for the poorly drawn line):

desired output


Solution

  • You can use scipy.interpolate.interp1d with kind='quadratic' to interpolate the data onto a denser grid of, say, 300 points, and then plot the result as a smooth curve.

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.interpolate import interp1d
    
    # original data
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    y = [2, 7, 3, 4, 5, 1, 6, 9, 4, 6]
    
    # quadratic regression
    for i in range(len(x) - 2):  # one window per run of three consecutive points
        sub_x = x[i:i+3]
        sub_y = y[i:i+3]
        
        model = np.poly1d(np.polyfit(sub_x, sub_y, 2))
        polyline = np.linspace(min(sub_x), max(sub_x), 200)
        plt.plot(polyline, model(polyline), color="#6D34D6", linestyle='dashed')
    
    
    # Interpolate
    x_new = np.linspace(min(x), max(x), 300)  #<----
    f = interp1d(x, y, kind='quadratic')      #<----
    
    # plot lines
    plt.scatter(x, y, color='#FF3FAF')
    plt.plot(x_new, f(x_new), color='#FF3FAF', linestyle='solid') #<----
    plt.show()
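
    As a hedged aside: recent SciPy documentation marks interp1d as legacy, so a similar curve can probably be built with scipy.interpolate.make_interp_spline instead. The sketch below assumes the same x and y as above and uses k=2 for a piecewise-quadratic spline; the names spline, x_new and y_new are just illustrative.

```python
import numpy as np
from scipy.interpolate import make_interp_spline

# original data (same values as in the question)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)
y = np.array([2, 7, 3, 4, 5, 1, 6, 9, 4, 6], dtype=float)

# k=2 requests a piecewise-quadratic spline that passes through every data point
spline = make_interp_spline(x, y, k=2)

# evaluate the spline on a dense grid so the plotted line looks smooth
x_new = np.linspace(x.min(), x.max(), 300)
y_new = spline(x_new)
```

    Plotting x_new against y_new with plt.plot then takes the place of the interp1d line in the code above. Using k=3 instead would give a cubic spline, which is often a little smoother.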
    

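
    The averaging idea from the question can also be made to work. The following is only a rough sketch, not part of the original answer: it fits one quadratic through every run of three consecutive points and, on each interval between two neighbouring points, blends the two quadratics that overlap there with a weight that slides from 0 to 1. The function name blended and the blending weights are assumptions chosen for illustration.

```python
import numpy as np

# original data (same values as in the question)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)
y = np.array([2, 7, 3, 4, 5, 1, 6, 9, 4, 6], dtype=float)

# quads[j] is the quadratic through points j, j+1 and j+2
quads = [np.poly1d(np.polyfit(x[j:j + 3], y[j:j + 3], 2))
         for j in range(len(x) - 2)]


def blended(x_eval):
    """Blend the overlapping quadratics at x_eval (values within [x[0], x[-1]])."""
    x_eval = np.atleast_1d(np.asarray(x_eval, dtype=float))
    # index i of the interval [x[i], x[i+1]] that contains each evaluation point
    i = np.clip(np.searchsorted(x, x_eval, side="right") - 1, 0, len(x) - 2)
    # position inside the interval: 0 at its left end, 1 at its right end
    t = (x_eval - x[i]) / (x[i + 1] - x[i])
    # interval i is covered by quads[i-1] and quads[i]; the first and last
    # intervals are covered by a single quadratic, so the indices are clipped
    left = np.clip(i - 1, 0, len(quads) - 1)
    right = np.clip(i, 0, len(quads) - 1)
    # evaluate every quadratic everywhere, then pick the two needed per point
    vals = np.array([q(x_eval) for q in quads])
    cols = np.arange(x_eval.size)
    return (1 - t) * vals[left, cols] + t * vals[right, cols]


x_new = np.linspace(x.min(), x.max(), 300)
y_new = blended(x_new)
```

    Because both quadratics that overlap on an interval pass through its two end points, the blended curve still goes through every data point, and neighbouring pieces should join without a visible kink. Plotting x_new against y_new with plt.plot draws the result.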