
Calculating the intersection points of 3 horizontal lines and a cubic spline


I have a problem similar to others posted here. I want to calculate the intersection points between a cubic spline and three horizontal lines. I know the y-value of each horizontal line, and I want to find the corresponding x-value where it intersects the spline. I hope you can help me; I am sure this is easy to solve for more experienced coders!

import matplotlib.pyplot as plt
from scipy import interpolate
import numpy as np

x = np.arange(0, 10)
y = np.exp(-x**2.0)

spline = interpolate.interp1d(x, y, kind = "cubic")
xnew = np.arange(0, 9, 0.1)
ynew = spline(xnew)   


x1=np.arange(0,10)
y1=1/10*np.ones(10)

x2=np.arange(0,10)
y2=2/10*np.ones(10)

x3=np.arange(0,10)
y3=3/10*np.ones(10)


plt.plot(x,y,'o', xnew, ynew, '-', x1,y1, '-.', x2,y2, '-.', x3,y3, '-.')
plt.show()


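list_idx = []
for y_line in (y1, y2, y3):
    # indices into xnew where the sampled spline changes sign relative to the line
    idx = np.argwhere(np.diff(np.sign(ynew - y_line[0]))).flatten()
    list_idx.append(idx)

print(list_idx)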

Solution

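  • You can use the roots() method of scipy.interpolate.InterpolatedUnivariateSpline to find the intersections. First subtract the line's y-value from the data, then fit a spline and find its roots; each root is an x-value at which the original curve equals that y-value. Note that roots() is only supported for cubic splines (k=3, which is the default).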

    import matplotlib.pyplot as plt
    from scipy import interpolate
    import numpy as np
    
    x = np.arange(0, 10)
    y = np.exp(-x**2.0)
    
    spline = interpolate.interp1d(x, y, kind = "cubic")
    xnew = np.arange(0, 9, 0.1)
    ynew = spline(xnew)   
    
    
    x1=np.arange(0,10)
    y1=1*np.ones(10)/10
    
    x2=np.arange(0,10)
    y2=2*np.ones(10)/10
    
    x3=np.arange(0,10)
    y3=3*np.ones(10)/10
    
    
    plt.plot(x,y,'o', xnew, ynew, '-', x1,y1, '-.', x2,y2, '-.', x3,y3, '-.')
    plt.show()
    
    
    y_val = 0.2
    func = np.array(y) - y_val
    sub_funct = interpolate.InterpolatedUnivariateSpline(x, func) # to find the roots we need to subtract y_val from the function
    root = sub_funct.roots() # find roots here
    print(root)
    

    This prints the x-value at which the spline crosses y_val = 0.2:

     [1.36192179]
    

    EDIT

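    To mark the intersection on the figure, draw an arrow from each root down to the x-axis (call this before plt.show(), otherwise it ends up on a new, empty figure):

    for r in root:
        plt.arrow(r, y_val, 0, -y_val, head_width=0.2, head_length=0.06)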
    

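The answer above finds the crossing for a single level (y_val = 0.2). A minimal sketch extending the same roots() approach to all three horizontal lines from the question (y = 0.1, 0.2, 0.3) might look like this. The names y_levels and roots_per_level are illustrative, not from the original post, and the roots come from the InterpolatedUnivariateSpline fit rather than from the interp1d object used for plotting.

```python
import numpy as np
from scipy import interpolate

# Same sample data as in the question
x = np.arange(0, 10)
y = np.exp(-x**2.0)

# y-values of the three horizontal lines (illustrative name)
y_levels = [0.1, 0.2, 0.3]

roots_per_level = {}
for y_val in y_levels:
    # Shift the data down by y_val so each crossing becomes a zero,
    # then fit an interpolating cubic spline; roots() requires k=3.
    shifted = interpolate.InterpolatedUnivariateSpline(x, y - y_val, k=3)
    roots_per_level[y_val] = shifted.roots()

for y_val, xs in roots_per_level.items():
    print(f"y = {y_val}: x = {xs}")
```

Because roots() returns an array, a level that the curve crosses more than once yields several x-values. Here every level should be crossed at least once between x = 1 and x = 2, where the data falls from about 0.37 to 0.02.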