Tags: python, numpy, multidimensional-array, scipy, model-fitting

NumPy polyfit (or any fitting) on X and Y multidimensional arrays


I have two large multidimensional arrays: Y holds three measurements for each of half a million objects (i.e. shape=(500000,3)), and X has the same shape but contains the positions of the Y measurements.

First, I would like to fit a polynomial to each row (one object per row). I know that iterating over arrays is quite slow, but this is what I'm doing at the moment:

    import numpy as np
    fit = np.array([np.polyfit(X[i], Y[i], deg) for i in range(X.shape[0])])

My question is: is there any way to fit each row of both arrays without explicitly iterating over them?


Solution

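  • It is possible to do so without iterating along the first axis. However, your second axis is rather short (only 3 points per row), so you can really fit no more than 2 coefficients if you still want an overdetermined least-squares fit; with 3 coefficients (a quadratic) the polynomial simply passes through all three points.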

    In [67]:
    
    import numpy as np
    import scipy.optimize as so
    
    In [68]:
    
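    def MD_ployError(p, x, y):
        '''Sum of squared residuals of a per-row polynomial fit.
        x and y must both have shape (n, m); p is a flat array of shape (n*k,)
        holding k = deg+1 coefficients per row, constant term first.'''
        p_rshp = p.reshape((x.shape[0], -1))   # (n, k): one row of coefficients per object
        f = y * 1.                             # float copy of y; becomes the residuals
        for i in range(p_rshp.shape[1]):       # loops over powers only, not over rows
            f -= p_rshp[:, i][:, np.newaxis] * (x**i)
        return (f**2).sum()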
    
    In [69]:
    
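    X = np.random.random((100, 6))     # 100 objects, 6 points each
    Y = 4 + 2*X + 3*X*X                # true coefficients (4, 2, 3) for every row
    P = np.ones((100, 3)).ravel()      # initial guess: all 300 coefficients set to 1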
    
    In [70]:
    
    MD_ployError(P, X, Y)
    
    Out[70]:
    11012.2067606684
    
    In [71]:
    
    R=so.fmin_slsqp(MD_ployError, P, args=(X, Y))
    Iteration limit exceeded    (Exit mode 9)  # you can raise fmin_slsqp's iter limit (default 100), but the result is already good enough.
                Current function value: 0.00243784856039
                Iterations: 101
                Function evaluations: 30590
                Gradient evaluations: 101
    
    In [72]:
    
    R.reshape((100, -1))
    
    Out[72]:
    array([[ 3.94488512,  2.25402422,  2.74773571],
           [ 4.00474864,  1.97966551,  3.02010015],
           [ 3.99919559,  2.0032741 ,  2.99753804],
           ...])
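Because the coefficients enter the model linearly, each row is an independent linear least-squares problem, so the fit can also be done in closed form for all rows at once instead of with an iterative optimizer. A minimal sketch, assuming every row has the same number of points and that plain unweighted least squares is wanted; `batched_polyfit` is a hypothetical helper, not a NumPy function. It stacks one Vandermonde matrix per row and relies on `np.linalg.pinv` and the `@` operator both working on stacked `(..., M, N)` arrays:

```python
import numpy as np

def batched_polyfit(x, y, deg):
    """Hypothetical helper: least-squares polynomial fit of each row of y
    against the same row of x, with no Python loop over rows.

    x, y: arrays of shape (n, m). Returns coefficients of shape (n, deg + 1),
    constant term first (the reverse of np.polyfit's highest-power-first order).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Stacked Vandermonde matrices: V[r, j, i] = x[r, j] ** i, shape (n, m, deg + 1)
    V = x[..., np.newaxis] ** np.arange(deg + 1)
    # Batched pseudo-inverse, shape (n, deg + 1, m); each row is solved independently
    V_pinv = np.linalg.pinv(V)
    # Apply each row's pseudo-inverse to its own y row: (n, k, m) @ (n, m, 1) -> (n, k, 1)
    return (V_pinv @ y[..., np.newaxis])[..., 0]

# Same synthetic data as the session above: every row should give about [4, 2, 3]
X = np.random.random((100, 6))
Y = 4 + 2*X + 3*X*X
coef = batched_polyfit(X, Y, 2)
print(coef[:3])
```

With the question's shape `(500000, 3)` and `deg=2`, each stacked matrix is 3-by-3, so the fit interpolates the three points exactly; `deg=1` gives a genuine least-squares line. To match the column order of the loop in the question (highest power first, as `np.polyfit` returns), reverse the last axis: `batched_polyfit(X, Y, deg)[:, ::-1]`.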