
Python/SciPy: How to get cubic spline equations from CubicSpline


I am generating a graph of a cubic spline through a given set of data points:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate

x = np.array([1, 2, 4, 5])  # x values must be strictly increasing
y = np.array([2, 1, 4, 3])
arr = np.arange(np.amin(x), np.amax(x), 0.01)
s = interpolate.CubicSpline(x, y)
plt.plot(x, y, 'bo', label='Data Point')
plt.plot(arr, s(arr), 'r-', label='Cubic Spline')
plt.legend()
plt.show()

How can I get the spline equations from CubicSpline? I need the equations in the form:

$$S_i(x) = y_i + b_i (x - x_i) + c_i (x - x_i)^2 + d_i (x - x_i)^3, \quad x_i \le x \le x_{i+1}$$

I've tried several ways of getting the coefficients, but all of them rely on data other than just the data points.


Solution

  • From the CubicSpline documentation (attribute c):

    c (ndarray, shape (4, n-1, ...)) Coefficients of the polynomials on each segment. The trailing dimensions match the dimensions of y, excluding axis. For example, if y is 1-d, then c[k, i] is a coefficient for (x-x[i])**(3-k) on the segment between x[i] and x[i+1].
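A quick way to confirm this layout on the question's data is to inspect the `c` and `x` attributes directly. This is a minimal sketch that relies only on the documented `c` and `x` attributes of `CubicSpline`. The check on row 3 reflects the fact that the constant term of each segment is the spline's value at that segment's left breakpoint, which for an interpolating spline is the data value there.

```python
import numpy as np
from scipy import interpolate

# Same data as in the question; x must be strictly increasing.
x = np.array([1, 2, 4, 5])
y = np.array([2, 1, 4, 3])
s = interpolate.CubicSpline(x, y)

# One column per segment: 4 coefficients for each of the len(x) - 1 segments.
print(s.c.shape)  # (4, 3)

# The breakpoints are stored in s.x (the same values as x, stored as floats).
print(s.x)

# Row 3 multiplies (x - x[i])**0, so it holds the value at the left end of
# each segment, i.e. the data values y[0], ..., y[-2].
print(np.allclose(s.c[3], y[:-1]))  # True
```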

    So in your example, the coefficients for the first segment [x1, x2] would be in column 0:

    • y1 would be s.c[3, 0]
    • b1 would be s.c[2, 0]
    • c1 would be s.c[1, 0]
    • d1 would be s.c[0, 0].

    Then for the second segment [x2, x3] you would have s.c[3, 1], s.c[2, 1], s.c[1, 1] and s.c[0, 1] for y2, b2, c2, d2, and so on.
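To print the equations in the form asked for, the columns of `s.c` can be read one segment at a time. The helper `segment_equations` below is hypothetical (it is not part of SciPy), and the string format (`^` for powers, six significant digits, explicit signs on each coefficient) is an arbitrary choice for readability.

```python
import numpy as np
from scipy import interpolate

x = np.array([1, 2, 4, 5])
y = np.array([2, 1, 4, 3])
s = interpolate.CubicSpline(x, y)

def segment_equations(spline):
    """Hypothetical helper: one string per segment, in the form
    S_i(x) = y_i + b_i(x - x_i) + c_i(x - x_i)^2 + d_i(x - x_i)^3."""
    eqs = []
    for i in range(spline.c.shape[1]):
        # Rows of spline.c run from the cubic coefficient down to the constant.
        d, c, b, yi = spline.c[:, i]
        xi, xi1 = spline.x[i], spline.x[i + 1]
        eqs.append(
            f"S_{i + 1}(x) = {yi:.6g} {b:+.6g}(x - {xi:g}) "
            f"{c:+.6g}(x - {xi:g})^2 {d:+.6g}(x - {xi:g})^3, "
            f"for {xi:g} <= x <= {xi1:g}"
        )
    return eqs

for eq in segment_equations(s):
    print(eq)
```

Each printed equation is only valid on its own interval; outside it, a different column of `s.c` applies.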

    For example, the following rebuilds each segment from its column of s.c and plots it on top of the full spline:

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy import interpolate

    x = np.array([1, 2, 4, 5])  # x values must be strictly increasing
    y = np.array([2, 1, 4, 3])
    arr = np.arange(np.amin(x), np.amax(x), 0.01)
    s = interpolate.CubicSpline(x, y)
    
    # ax.hold() was removed in Matplotlib 3.0; repeated plot calls already
    # draw on the same axes, so no hold call is needed.
    fig, ax = plt.subplots(1, 1)
    ax.plot(x, y, 'bo', label='Data Point')
    ax.plot(arr, s(arr), 'k-', label='Cubic Spline', lw=1)
    
    for i in range(x.shape[0] - 1):
        segment_x = np.linspace(x[i], x[i + 1], 100)
        # A (4, 100) array whose rows are (x-x[i])**3, (x-x[i])**2, (x-x[i])**1, (x-x[i])**0
        exp_x = (segment_x - x[i])[None, :] ** np.arange(4)[::-1, None]
        # Sum over the rows of exp_x weighted by the coefficients in the ith column of s.c
        segment_y = s.c[:, i].dot(exp_x)
        ax.plot(segment_x, segment_y, label='Segment {}'.format(i), ls='--', lw=3)
    
    ax.legend()
    plt.show()
    

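As a numerical check of the segment-by-segment reconstruction without plotting, the sketch below evaluates each segment from its column of `s.c` and compares it with `s` itself. `eval_segment` is a hypothetical helper; it uses `np.polyval`, which takes coefficients ordered from the highest power down, matching the row order of `s.c`. The last three checks pass a derivative order (`nu`) to the spline call to show that b_i, c_i and d_i equal S'(x_i), S''(x_i)/2 and S'''(x_i)/6 at the left breakpoint of each segment.

```python
import numpy as np
from scipy import interpolate

x = np.array([1, 2, 4, 5])
y = np.array([2, 1, 4, 3])
s = interpolate.CubicSpline(x, y)

def eval_segment(spline, i, t):
    """Hypothetical helper: evaluate segment i at points t using only spline.c."""
    dt = t - spline.x[i]
    # np.polyval wants coefficients from highest to lowest power,
    # which is exactly the row order of spline.c.
    return np.polyval(spline.c[:, i], dt)

max_err = 0.0
for i in range(len(x) - 1):
    t = np.linspace(x[i], x[i + 1], 50)
    max_err = max(max_err, np.max(np.abs(eval_segment(s, i, t) - s(t))))
print(max_err)  # expected to be at the level of floating-point rounding

# The coefficients are the Taylor coefficients at each segment's left breakpoint.
left = s.x[:-1]
print(np.allclose(s.c[2], s(left, nu=1)))      # b_i = S'(x_i)
print(np.allclose(s.c[1], s(left, nu=2) / 2))  # c_i = S''(x_i) / 2
print(np.allclose(s.c[0], s(left, nu=3) / 6))  # d_i = S'''(x_i) / 6
```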