Search code examples
pythonnumerical-methods

Variable input inside a function


I am trying to make a dormand-prince type function that i can use for all cases(2,3,4.....variable coupled differential equation). I need to make this function usuable for all cases.

So lotkav is two variable coupled DEQ and lorenz is three variable coupled DEQ. In my written function, I have to evaluate func for different number of input argument as it can be lotkav or lorenz.

bt = np.array([[0.   , 0.        , 0.         ,0.         ,0.       , 0.         , 0.   ], # h
           [1/5. , 1/5.      , 0.         ,0.         ,0.       , 0.         , 0.   ], # k0
           [3/10.,3/40.      , 9/40.      ,0.         ,0.       , 0.         , 0.   ], # k1
           [4/5. ,44/45.     ,-56/15.     ,32/9.      ,0.       , 0.         , 0.   ], # k2
           [8/9. ,19372/6561.,-25360/2187 ,64448/6561 ,-212/729., 0.         , 0.   ], # k3
           [1.  ,9017/3168. ,-355/33.    ,46732/5247 ,49/176   ,-5103/18656. , 0.   ], # k4
           [1   ,35/384     ,0           ,500/1113   ,125/192  ,-2187/6784.  ,11/84.]]) # k5

def lotkav(t, x, y):
    return [x*(a-b*y), -y*(c-d*x)]

def lorenz(t, x, y, z):
    return [sigma*(y-x), x*(rho-z)-y, x*y - beta*z]

def ode45(func, t_span, y0, hi = .001, hmax = .01, hmin = .000001, tol = 0.00000001):
    k = np.zeros(shape=(7,len(y0)))

    h = hi
    vi = y0.copy()
    yp = np.zeros(len(y0))  # fifth order solution
    zp = np.zeros(len(y0))  # sixth order solution
    ts = [ ]             # time at which values are calculated
    ts.append(t_span[0])
    t = t_span[0]
    yr = [y0]
    ri = np.zeros(len(y0))
    sig = signature(func)

    while t <= t_span[1]:
    
        for i in range(7):
            sum = np.zeros(len(y0))
            ti = t + bt[i][0]*h
            for j in range(0,i):
                sum = np.add(sum, bt[i][j+1]*k[j])
        
            for narg in range(len(y0)):
                ri[narg] = vi[narg] + sum[narg]*h
            
            # I want to change the following line so it works for both without if else.
            k[i] = func(ti, ri[0], ri[1]) # for lorenz it will be k[i] = func(ti, ri[0], ri[1], ri[2])

        for i in range(len(y0)):
            yp[i] = vi[i] + ((35/384.)*k[0][i]+(500/1113.)*k[2][i]+(125/192.)*k[3][i]+(-2187/6784.)*k[4][i]+(11/84.)*k[5][i])*h
            zp[i] = vi[i] + ((5179/57600.)*k[0][i]+(7571/16695.)*k[2][i]+(393/640.)*k[3][i]+(-92097/339200.)*k[4][i]+(187/2100)*k[5][i])*h+(1/40)*k[6][i]*h

        err = np.min(np.abs(np.subtract(yp, zp)))

    
        if  0 < err < tol: # if error within tolerance accept result and try for larger step
            yr.append(zp.copy())
            t = t + h
            ts.append(t)
            vi = zp
            h = 2.0*h # increase time step by 20%
            if h > hmax:
                h = hmax
        elif err == 0: # if error becomes 0 then do the same
            yr.append(zp.copy())
            t = t + h
            ts.append(t)
            vi = zp
            h = 2.0*h
            if h > hmax:
                h = hmax
        else:         # error is not within limit reduce the step size
            h = h*(tol*h/(2*(err)))**(1/5.0)
            if h < hmin:
                h = hmin

return ts,yr

tp, yp = ode45(lotkav, [0,50], y0)
t,r = ode45(lorenz, [0,100], y0)

I know *arg could have been used in lorenz or lotkav function parameter but scipy odeint do not require to do so.


Solution

  • Write a wrapper function that takes a variable number of arguments - either 3 or 4 values for the lotkav and lorenz functions respectively.

    Something like this:

    def lotkav(a, b, c):
        return "lotkav"
    
    
    def lorenz(a, b, c, d):
        return "lorenz"
    
    
    def lowrapper(*args):
        match len(args):
            case 3:
                return lotkav(*args)
            case 4:
                return lorenz(*args)
            case _:
                raise Exception("Invalid number of arguments")
    
    ti = 1
    ri = [2, 3]
    print(lowrapper(ti, *ri))
    ri = [2, 3, 4]
    print(lowrapper(ti, *ri))
    

    Output:

    lotkav
    lorenz