
Using NumPy functions in Cython for least-squares fitting of array elements


I need to write a script that does least-squares fitting, pixel by pixel, for a stack of four similar 500x500 images. That is, for each pixel location I need to fit the four values at that location (one from each image) to a vector of length three, using the same 4x3 matrix for every pixel.
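In symbols, each pixel is an independent least-squares problem that shares the same left-hand matrix. The notation is introduced here only for illustration: $U$ is the shared 4x3 matrix, $g_{ij} \in \mathbb{R}^4$ holds the four image values at pixel $(i, j)$, and $\hat{x}_{ij} \in \mathbb{R}^3$ is the fitted vector. When $U$ has full column rank, the normal equations give the closed form on the second line. Its matrix factor $(U^{\mathsf{T}} U)^{-1} U^{\mathsf{T}}$ does not depend on the pixel.

```latex
\hat{x}_{ij} = \operatorname*{arg\,min}_{x \in \mathbb{R}^{3}} \lVert U x - g_{ij} \rVert_2^{2},
\qquad U \in \mathbb{R}^{4 \times 3},\quad g_{ij} \in \mathbb{R}^{4},\quad 0 \le i, j < 500

\hat{x}_{ij} = (U^{\mathsf{T}} U)^{-1} U^{\mathsf{T}} g_{ij}
\qquad \text{(when } U \text{ has full column rank)}
```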

I don't see a way to do this without a nested for-loop over every pixel, so I figured Cython could speed things up. I have never worked with Cython before, but I wrote the following code based on the examples in the documentation.

The issue is that this runs as slowly as, or slower than, a pure Python implementation (~27 s versus ~25 s).

Does anyone see what is slowing this down? Thanks!

import numpy as np
cimport numpy as np
cimport cython

npint = np.int16
npfloat = np.float64

ctypedef np.int16_t npint_t
ctypedef np.float64_t npfloat_t


@cython.boundscheck(False)
@cython.wraparound(False)
def fourbythree(np.ndarray[npfloat_t, ndim=2] U_mat, np.ndarray[npint_t, ndim=3] G):
    assert U_mat.dtype == npfloat and G.dtype == npint
    cdef unsigned int z = G.shape[0]
    cdef unsigned int rows = G.shape[1]
    cdef unsigned int cols = G.shape[2]
    cdef np.ndarray[npfloat_t, ndim= 3] a  = np.empty((z - 1, rows, cols), dtype=npfloat)
    cdef npfloat_t resid
    cdef unsigned int rank
    cdef Py_ssize_t row, col
    cdef np.ndarray s

    for row in range(rows):
        for col in range(cols):
            # solve the 4x3 least-squares problem for this pixel's four values
            a[:, row, col] = np.linalg.lstsq(U_mat, G[:, row, col], rcond=None)[0]
    return a
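For comparison, an untyped pure-Python version of the same loop might look like the sketch below. The function name `fourbythree_py` is hypothetical, and the question doesn't show the actual pure-Python code that took ~25 s. Apart from the type declarations, this version does the same work as the Cython one. That the two timings are so close suggests the time goes into the per-pixel `np.linalg.lstsq` calls rather than into the loop itself.

```python
import numpy as np


def fourbythree_py(U_mat, G):
    """Per-pixel least squares in plain Python (hypothetical helper name).

    U_mat: (k, m) float array shared by every pixel (4x3 in the question).
    G:     (k, rows, cols) stack of k images.
    Returns an (m, rows, cols) array of fitted coefficients.
    """
    k, rows, cols = G.shape
    m = U_mat.shape[1]
    a = np.empty((m, rows, cols), dtype=np.float64)
    for row in range(rows):
        for col in range(cols):
            # One Python-level call into lstsq per pixel: 250,000 calls
            # for a 500x500 image, each on a tiny 4x3 system.
            a[:, row, col] = np.linalg.lstsq(U_mat, G[:, row, col], rcond=None)[0]
    return a


# Small example: 4 images of 5x6 pixels and a random 4x3 matrix.
rng = np.random.default_rng(0)
U = rng.standard_normal((4, 3))
G = rng.integers(-100, 100, size=(4, 5, 6)).astype(np.int16)
print(fourbythree_py(U, G).shape)  # (3, 5, 6)
```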

Solution

  • You shouldn't need the iteration: you can do it all in a single call to lstsq. lstsq allows the second argument to be 2D, in which case the result is also 2D. Your array is 3D, but you can readily reshape it to 2D and then reshape the output back. The reshape is essentially free: for a contiguous array it returns a view rather than copying the data.

    a = np.linalg.lstsq(U_mat, G.reshape((G.shape[0], -1)), rcond=None)[0]
    a = a.reshape((a.shape[0], G.shape[1], G.shape[2]))
    

    This is all untyped, pure Python code. There's no element-by-element indexing here, so I don't expect Cython to help.

    I get something like a 400x speed-up from this, although some of that is because the "one call" version appears to run in parallel and the Cython version doesn't. I think the main reason for the speed-up is the overhead of calling a Python function repeatedly. The loop makes 250,000 separate lstsq calls, one per pixel of a 500x500 image, and each call works on a tiny 4x3 problem.
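Below is a self-contained sketch of the single-call approach above, wrapped in a function, with a spot check of one pixel against an individual lstsq call. The function names `fourbythree_vectorized` and `fourbythree_pinv` are hypothetical. The second function uses a swapped-in alternative that the answer doesn't mention. Because `U_mat` never changes, its pseudoinverse can be computed once with `np.linalg.pinv` and applied to every pixel with a single matrix product. For a full-column-rank `U_mat`, this gives the same least-squares solution.

```python
import numpy as np


def fourbythree_vectorized(U_mat, G):
    """Solve every pixel's least-squares problem in one lstsq call.

    Reshapes the (k, rows, cols) stack to (k, rows*cols) so each pixel is
    one column of the right-hand side, solves, then reshapes back.
    """
    k, rows, cols = G.shape
    B = G.reshape((k, -1))                        # (k, rows*cols); a view if G is contiguous
    X = np.linalg.lstsq(U_mat, B, rcond=None)[0]  # (m, rows*cols)
    return X.reshape((X.shape[0], rows, cols))


def fourbythree_pinv(U_mat, G):
    """Alternative (not from the answer): apply pinv(U_mat) to all pixels.

    The pseudoinverse is computed once; each pixel's solution is then one
    column of a single matrix product. Matches lstsq when U_mat has full
    column rank.
    """
    k, rows, cols = G.shape
    P = np.linalg.pinv(U_mat)                     # (m, k)
    X = P @ G.reshape((k, -1)).astype(np.float64)
    return X.reshape((P.shape[0], rows, cols))


# Small example: 4 images of 50x60 pixels and a random 4x3 matrix.
rng = np.random.default_rng(0)
U = rng.standard_normal((4, 3))
G = rng.integers(-1000, 1000, size=(4, 50, 60)).astype(np.int16)

a = fourbythree_vectorized(U, G)
b = fourbythree_pinv(U, G)
print(a.shape)  # (3, 50, 60)

# Spot-check one pixel against a single per-pixel lstsq call,
# and compare the two whole-stack approaches.
ref = np.linalg.lstsq(U, G[:, 10, 20], rcond=None)[0]
print(np.allclose(a[:, 10, 20], ref), np.allclose(a, b))
```

Which of the two whole-stack versions is faster on real 500x500 data is worth timing rather than assuming. Both avoid the 250,000 separate Python-level calls.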