python scipy interpolation spline

Batch make_smoothing_spline in scipy


In scipy, the function scipy.interpolate.make_interp_spline() supports batching: its x argument must be one-dimensional with shape (m,), while its y argument can have shape (m, ...).
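
To illustrate the batched behavior I mean, here is a minimal sketch with made-up data (the arrays and the names x, y, spl, and xnew are only illustrative):

```python
import numpy as np
from scipy.interpolate import make_interp_spline

x = np.linspace(0.0, 1.0, 10)                        # shape (m,) with m = 10
y = np.stack([x**2, np.sin(x), np.exp(x)], axis=1)   # shape (m, 3)

# A single call returns one spline object that covers all three columns.
spl = make_interp_spline(x, y)

xnew = np.linspace(0.0, 1.0, 25)
print(spl(xnew).shape)  # (25, 3)
```

The single object spl evaluates every column at once, which is the behavior I would like from make_smoothing_spline() as well.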

However, the function scipy.interpolate.make_smoothing_spline() only accepts a y argument of shape (m,).

Is there a simple way to batch the behavior of make_smoothing_spline() so it has the same behavior as make_interp_spline()?

I was thinking of using numpy.vectorize(), but I'm not batching operations on an array here; I need a single function as output.

I guess I could just implement a loop and make a nested list of splines, but I was wondering if there would be a neater way.

Probably some combination of decorators would do it, but I'm twisting my brain in knots...

EDIT: Developers seem to be aware of this issue here.


Solution

  • The PR that added batch support to make_smoothing_spline happened to be merged a few hours before this post: https://github.com/scipy/scipy/pull/22484

    The feature will be available in SciPy 1.16, or you can get it early from the next nightly wheels: https://anaconda.org/scientific-python-nightly-wheels/scipy

    See also the BatchSpline class used in the tests of that PR.

    import copy

    import numpy as np


    class BatchSpline:
        # BSpline-like class with reference batch behavior
        def __init__(self, x, y, axis, *, spline, **kwargs):
            # Move the data axis last, then fit one spline per 1-D slice.
            y = np.moveaxis(y, axis, -1)
            self._batch_shape = y.shape[:-1]
            self._splines = [spline(x, yi, **kwargs) for yi in y.reshape(-1, y.shape[-1])]
            self._axis = axis

        def __call__(self, x):
            # x is expected to be a NumPy array, since x.shape is used below.
            y = [spline(x) for spline in self._splines]
            y = np.reshape(y, self._batch_shape + x.shape)
            # Put the evaluation axis back where the data axis was.
            return np.moveaxis(y, -1, self._axis) if x.shape else y

        def integrate(self, a, b, extrapolate=None):
            y = [spline.integrate(a, b, extrapolate) for spline in self._splines]
            return np.reshape(y, self._batch_shape)

        def derivative(self, nu):
            # Work on a copy so the original batch of splines is unchanged.
            res = copy.deepcopy(self)
            res._splines = [spline.derivative(nu) for spline in res._splines]
            return res

        def antiderivative(self, nu):
            res = copy.deepcopy(self)
            res._splines = [spline.antiderivative(nu) for spline in res._splines]
            return res
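
Until you can upgrade, the same looping idea can be packed into a small helper. The sketch below is only an illustration under a few assumptions: batched_smoothing_spline is a hypothetical name (not a SciPy function), the data are made up, and the returned callable only handles a scalar or one-dimensional xnew.

```python
import numpy as np
from scipy.interpolate import make_smoothing_spline


def batched_smoothing_spline(x, y, axis=0, **kwargs):
    """Fit one smoothing spline per 1-D slice of y taken along axis.

    Hypothetical helper, not part of SciPy. Returns a single callable
    that evaluates every spline and restores the original axis layout.
    """
    # Move the data axis last so each row of the reshaped array is one series.
    y = np.moveaxis(np.asarray(y, dtype=float), axis, -1)
    batch_shape = y.shape[:-1]
    splines = [make_smoothing_spline(x, yi, **kwargs)
               for yi in y.reshape(-1, y.shape[-1])]

    def evaluate(xnew):
        xnew = np.asarray(xnew, dtype=float)
        out = np.array([s(xnew) for s in splines])
        out = out.reshape(batch_shape + xnew.shape)
        # For 1-D xnew, put the evaluation axis back where the data axis was.
        return np.moveaxis(out, -1, axis) if xnew.ndim else out

    return evaluate


# Made-up noisy data: three series sampled on the same grid, shape (50, 3).
x = np.linspace(0.0, 2.0 * np.pi, 50)
rng = np.random.default_rng(0)
y = np.stack([np.sin(x), np.cos(x), np.sin(2.0 * x)], axis=1)
y += 0.05 * rng.standard_normal(y.shape)

f = batched_smoothing_spline(x, y, axis=0)
xnew = np.linspace(0.0, 2.0 * np.pi, 200)
print(f(xnew).shape)  # (200, 3)
```

This gives you a single function as output rather than a nested list of splines, at the cost of a Python-level loop over the batch. Once you are on a SciPy version that includes the PR above, the helper should no longer be needed.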