Tags: python, curve-fitting, python-decorators

How to decorate a function that is used with curve_fit?


I am creating a degree-n polynomial that I want to fit to my data. The decorator is supposed to print "test" every time the function is run. However, because of the decorator, curve_fit no longer knows how many parameters it should supply, and I get the error: ValueError: Unable to determine number of fit parameters. Is there a way to make curve_fit "know" how many parameters it should supply?

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


def print_dec(func_):                           # prints test before function
    def inner(*args, **kwargs):
        print("test")

        returned_value = func_(*args, **kwargs)

        return returned_value

    return inner


n = 10
func_str = "lambda x"                           # function creation
for i in range(n+1):
    func_str += f", var_{i}"
func_str += ": var0"
for i in range(1, n+1):
    func_str += f" + var_{i}*x**{i}"
print(func_str)
func = eval(func_str)

func = print_dec(func)

x_data = np.linspace(0, 20, 1000)
y_data = x_data + np.random.random(1000)        # data is just a straight line with noise

p_opt, p_cov = curve_fit(func, xdata=x_data, ydata=y_data)   # simple curve_fit
print(p_opt)

plt.plot(x_data, func(x_data, *p_opt), c='r')   # fitted curve
plt.scatter(x_data, y_data)                     # original data
plt.show()
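
To see why curve_fit cannot count the parameters, it helps to compare what inspect.signature reports for a plainly wrapped function and for one wrapped with functools.wraps. This is a stdlib-only sketch; line and print_dec_wraps are hypothetical names used only for illustration. As far as I can tell, recent SciPy versions infer the number of fit parameters from the function's signature via inspect.signature, which follows the __wrapped__ attribute that functools.wraps sets. If that holds for your SciPy version, adding @functools.wraps(func_) to print_dec should let curve_fit infer the count again. Passing p0 does not depend on that detail, so it is the more robust fix.

```python
import functools
import inspect


def print_dec(func_):
    # Plain wrapper, as in the question: its own signature is (*args, **kwargs).
    def inner(*args, **kwargs):
        print("test")
        return func_(*args, **kwargs)

    return inner


def print_dec_wraps(func_):
    # Same wrapper, but functools.wraps copies the metadata and sets
    # inner.__wrapped__ = func_, which inspect.signature follows by default.
    @functools.wraps(func_)
    def inner(*args, **kwargs):
        print("test")
        return func_(*args, **kwargs)

    return inner


def line(x, a, b):
    # Hypothetical two-parameter model used only for this comparison.
    return a + b * x


plain = print_dec(line)
wrapped = print_dec_wraps(line)

print(inspect.signature(plain))    # (*args, **kwargs): no way to count a, b
print(inspect.signature(wrapped))  # (x, a, b): two fit parameters after x
```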

Solution

  • Your decorated function accepts its parameters dynamically through *args and **kwargs, so curve_fit cannot tell from its signature how many fit parameters there are. You can pass an initial guess with one entry per parameter, e.g. p0=[1]*(n+1) in your case, and curve_fit will take the number of parameters from the length of p0, which solves the error. The curve_fit documentation for the p0 parameter describes this behavior.

    You'll also need to change the line func_str += ": var0" to func_str += ": var_0". The lambda's parameter is named var_0, so var0 would raise a NameError as soon as the function is called.

    Also, instead of building a lambda from a string and calling eval on it, you could write an ordinary function that loops over the coefficients. It would be easier to debug, and probably easier to read too.
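
Putting the suggestions together, a minimal sketch might look like the following. It assumes a loop-based polynomial (poly is a hypothetical name) that takes its coefficients as *coeffs, so curve_fit still cannot infer the count from the signature and p0 has to supply it. The degree is lowered to 3 here (the question uses 10) to keep the example small and better conditioned, and the noise is drawn from a seeded generator so runs are repeatable. The decorator is the question's print_dec with functools.wraps added, so it still prints "test" on every call.

```python
import functools

import numpy as np
from scipy.optimize import curve_fit


def print_dec(func_):
    # The question's decorator, plus functools.wraps to keep the wrapped
    # function's name, docstring and __wrapped__ attribute.
    @functools.wraps(func_)
    def inner(*args, **kwargs):
        print("test")
        return func_(*args, **kwargs)

    return inner


@print_dec
def poly(x, *coeffs):
    # coeffs[i] multiplies x**i, like var_i*x**i in the question's lambda.
    result = np.zeros_like(x, dtype=float)
    for i, c in enumerate(coeffs):
        result = result + c * x**i
    return result


n = 3                                      # polynomial degree (the question uses 10)
x_data = np.linspace(0, 20, 1000)
rng = np.random.default_rng(0)             # seeded so the run is repeatable
y_data = x_data + rng.random(1000)         # straight line with noise

# poly takes *coeffs, so its signature does not reveal the parameter count;
# the length of p0 (n + 1 coefficients) tells curve_fit how many to fit.
p_opt, p_cov = curve_fit(poly, x_data, y_data, p0=[1.0] * (n + 1))
print(p_opt)
```

The plotting lines from the question work unchanged if func is replaced by poly. For a pure polynomial, numpy.polynomial.polynomial.polyfit(x_data, y_data, n) solves the same least-squares problem directly, though that sidesteps the decorator question.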