Tags: c++, algorithm, linear-regression, curve-fitting

Sub-quadratic algorithm for fitting a curve with two lines


The problem is to find the best fit of a real-valued 2D curve (given as a set of points) by a polyline consisting of two line segments.

A brute-force approach would be to compute the "left" and "right" linear fits for each possible split point of the curve and pick the pair with the minimum total error. I can calculate the two linear fits incrementally while iterating through the points of the curve, but I can't find a way to incrementally calculate the error, so this approach yields quadratic complexity.
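
To make this concrete, the brute force I have in mind looks roughly like this sketch (numpy; each split refits both halves and recomputes the residuals from scratch, which is what makes the total cost quadratic):

import numpy as np

def brute_force_two_lines(x, y):
    # O(N^2): both fits and both error sums are recomputed for every split
    N = len(x)
    best = None
    for p in range(2, N - 1):          # at least two points on each side
        m1, b1 = np.polyfit(x[:p], y[:p], 1)
        m2, b2 = np.polyfit(x[p:], y[p:], 1)
        err = np.sum((np.polyval((m1, b1), x[:p]) - y[:p]) ** 2) + \
              np.sum((np.polyval((m2, b2), x[p:]) - y[p:]) ** 2)
        if best is None or err < best[0]:
            best = (err, p, m1, b1, m2, b2)
    return best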

The question is whether there is an algorithm with sub-quadratic complexity.

The second question is whether there is a handy C++ library for such algorithms.


EDIT: For fitting with a single line, there are closed-form formulas:

m = (Σxᵢyᵢ - Σxᵢ Σyᵢ/N) / (Σxᵢ² - (Σxᵢ)²/N)
b = Σyᵢ/N - m * Σxᵢ/N

where m is the slope and b is the offset of the line. Having a similar formula for the fit error would be the ideal solution to the problem.
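
For example, a quick numpy check of these formulas against a least-squares fit (my own snippet, just to confirm the algebra):

import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([0.1, 0.9, 2.2, 2.8, 4.1])
N = len(x)

m = (np.sum(x * y) - np.sum(x) * np.sum(y) / N) / (np.sum(x * x) - np.sum(x) ** 2 / N)
b = np.sum(y) / N - m * np.sum(x) / N

print(m, b)   # agrees with np.polyfit(x, y, 1) up to rounding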


Solution

  • Disclaimer: I don't feel like figuring out how to do this in C++, so I will use Python (numpy) notation. The concepts are completely transferable, so you should have no trouble translating back to the language of your choice.

    Let's say that you have a pair of arrays, x and y, containing the data points, and that x is monotonically increasing. Let's also say that you will always select a partition point that leaves at least two elements in each partition, so the equations are solvable.
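
    If your points don't come pre-sorted, a numpy argsort puts them in that form first (an O(N log N) step, so it doesn't affect the sub-quadratic goal):

    import numpy as np

    order = np.argsort(x)          # O(N log N) preprocessing
    x, y = x[order], y[order]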

    Now you can compute some relevant quantities:

    N = len(x)
    
    sum_x_left = x[0]
    sum_x2_left = x[0] * x[0]
    sum_y_left = y[0]
    sum_y2_left = y[0] * y[0]
    sum_xy_left = x[0] * y[0]
    
    sum_x_right = x[1:].sum()
    sum_x2_right = (x[1:] * x[1:]).sum()
    sum_y_right = y[1:].sum()
    sum_y2_right = (y[1:] * y[1:]).sum()
    sum_xy_right = (x[1:] * y[1:]).sum()
    

    The reason that we need these quantities (which are O(N) to initialize) is that you can use them directly to compute some well-known formulae for the parameters of a linear regression. For example, the optimal m and b for y = m * x + b are given by

    μx = Σxᵢ/N
    μy = Σyᵢ/N
    m = Σ(xᵢ - μx)(yᵢ - μy) / Σ(xᵢ - μx)²
    b = μy - m * μx
    

    The sum of squared errors is given by

    e = Σ(yᵢ - m * xᵢ - b)²
    

    These can be expanded using simple algebra into the following:

    m = (Σxᵢyᵢ - Σxᵢ Σyᵢ/N) / (Σxᵢ² - (Σxᵢ)²/N)
    b = Σyᵢ/N - m * Σxᵢ/N
    e = Σyᵢ² + m² * Σxᵢ² + N * b² - 2 * m * Σxᵢyᵢ - 2 * b * Σyᵢ + 2 * m * b * Σxᵢ
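
    If you prefer, these three expressions can be factored into a small helper (my own naming; the loop below keeps the same computation inline instead):

    def fit_from_sums(n, sx, sx2, sy, sy2, sxy):
        # closed-form slope, intercept and squared error from the running sums
        m = (sxy - sx * sy / n) / (sx2 - sx * sx / n)
        b = sy / n - m * sx / n
        e = sy2 + m * m * sx2 + n * b * b - 2 * m * sxy - 2 * b * sy + 2 * m * b * sx
        return m, b, e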
    

    You can therefore loop over all the possibilities and record the minimal e:

    # split point p: left partition is x[0..p], right partition is x[p+1..N-1]
    for p in range(1, N - 2):   # leaves at least two points in each partition
        # shift sums: O(1)
        sum_x_left += x[p]
        sum_x2_left += x[p] * x[p]
        sum_y_left += y[p]
        sum_y2_left += y[p] * y[p]
        sum_xy_left += x[p] * y[p]
    
        sum_x_right -= x[p]
        sum_x2_right -= x[p] * x[p]
        sum_y_right -= y[p]
        sum_y2_right -= y[p] * y[p]
        sum_xy_right -= x[p] * y[p]
    
        # compute err: O(1)
        n_left = p + 1
        slope_left = (sum_xy_left - sum_x_left * sum_y_left / n_left) / (sum_x2_left - sum_x_left * sum_x_left / n_left)
        intercept_left = sum_y_left / n_left - slope_left * sum_x_left / n_left
        err_left = sum_y2_left + slope_left * slope_left * sum_x2_left + n_left * intercept_left * intercept_left - 2 * (slope_left * sum_xy_left + intercept_left * sum_y_left - slope_left * intercept_left * sum_x_left)
    
        n_right = N - n_left
        slope_right = (sum_xy_right - sum_x_right * sum_y_right / n_right) / (sum_x2_right - sum_x_right * sum_x_right / n_right)
        intercept_right = sum_y_right / n_right - slope_right * sum_x_right / n_right
        err_right = sum_y2_right + slope_right * slope_right * sum_x2_right + n_right * intercept_right * intercept_right - 2 * (slope_right * sum_xy_right + intercept_right * sum_y_right - slope_right * intercept_right * sum_x_right)
    
        err = err_left + err_right
        if p == 1 or err < err_min:
            err_min = err
            n_min_left = n_left
            n_min_right = n_right
            slope_min_left = slope_left
            slope_min_right = slope_right
            intercept_min_left = intercept_left
            intercept_min_right = intercept_right
    

    There are probably other simplifications you can make, but this is sufficient to have an O(N) algorithm.
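
    If you wrap the initialization and the loop above in a function, say fit_two_lines(x, y), that returns the recorded minima (the wrapper name and the synthetic data below are mine), a quick sanity check looks like this:

    import numpy as np

    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 10.0, 200)
    # piecewise-linear ground truth with a knee at x = 4 (y is continuous there)
    y = np.where(x < 4.0, 2.0 * x + 1.0, -1.0 * x + 13.0)
    y = y + rng.normal(scale=0.1, size=x.size)

    err_min, n_left, slope_l, intercept_l, slope_r, intercept_r = fit_two_lines(x, y)
    print(slope_l, intercept_l)   # expect roughly 2, 1
    print(slope_r, intercept_r)   # expect roughly -1, 13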