
Dynamically fix function parameters in python?


Consider some generic function of many parameters:

def my_fun(x, a, b, c, d, e, f, g):
    return x, a, b, c, d, e, f, g

I want to create a kind of higher order function, which fixes an arbitrary collection of parameters of my_fun. For example:

fixing_function({"a": 100, "e": 200})

should return a function that looks like this:

def fixed_function(x, b, c, d, f, g):
    return x, 100, b, c, d, 200, f, g

It would be easy to do using exec, but I am sure there must be a more pythonic way to do something like this. The motivation for this question is the use of scipy.optimize.curve_fit(). I need to fit my_fun several times, changing which parameters are fixed and which are fitted. To my knowledge, this precludes using *args and **kwargs in the definition of my_fun.


Solution

  • I don't know whether the following is "more pythonic" or more efficient, but at least it works in this particular case:

    from inspect import signature
    
    def my_fun(x, a, b, c, d, e, f, g):
        return x, a, b, c, d, e, f, g
    
    def fixing_function(orig_function, fixing_dict):
        orig_params = signature(orig_function).parameters
        # Parameters that are not fixed, in their original order
        remaining_args = [key for key in orig_params if key not in fixing_dict]
    
        def new_function(*pos_args):
            # Map the positional arguments onto the free parameter names...
            zipped = dict(zip(remaining_args, pos_args))
            # ...then merge in the fixed values (dict union needs Python 3.9+)
            new_args = zipped | fixing_dict
            return orig_function(**new_args)
    
        return new_function
    
    fixed_function = fixing_function(my_fun, {"a": 100, "e": 200})
    
    print(fixed_function(1, 2, 3, 4, 5, 6))
    #output: (1, 100, 2, 3, 4, 200, 5, 6)
    

    However, it does not serve the OP's purpose: print(getfullargspec(fixed_function)) gives FullArgSpec(args=[], varargs='pos_args', varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={}), so curve_fit cannot infer the number of fit parameters from the signature unless p0 is passed explicitly. It may still be useful to other people.
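To see why the empty args list matters, here is a hypothetical helper, n_fit_params, that roughly mimics what curve_fit does when p0 is None: it counts the positional parameter names reported by getfullargspec and subtracts one for the independent variable. It is a sketch of the idea, not SciPy's actual code.

```python
from inspect import getfullargspec


def n_fit_params(func):
    """Hypothetical helper: count fit parameters roughly the way curve_fit does
    when p0 is None, i.e. all positional parameters except the first (x)."""
    args = getfullargspec(func).args
    if len(args) < 2:
        raise ValueError(f"cannot infer fit parameters of {func.__name__}")
    return len(args) - 1


def my_fun(x, a, b, c, d, e, f, g):
    return x, a, b, c, d, e, f, g


def new_function(*pos_args):
    # Stand-in with the same signature as the *args wrapper above
    return pos_args


print(n_fit_params(my_fun))  # 7
try:
    n_fit_params(new_function)
except ValueError as err:
    print(err)  # cannot infer fit parameters of new_function
```

Passing p0 explicitly to curve_fit sidesteps this introspection, because the number of parameters is then taken from the length of p0.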


    In case someone else wonders, here is my attempt using eval, which also satisfies getfullargspec:

    from inspect import signature, getfullargspec
    
    def my_fun(x, a, b, c, d, e, f, g):
        return x, a, b, c, d, e, f, g
    
    def fixing_function(orig_function, fixing_dict):
        orig_params = signature(orig_function).parameters
        # Parameters that are not fixed become the lambda's own parameters
        remaining_args = [key for key in orig_params if key not in fixing_dict]
        # Fixed values are looked up in _fixed instead of being pasted in via str(),
        # which would break for values such as strings or arrays
        arg_substitute = [f"_fixed[{key!r}]" if key in fixing_dict else key
                          for key in orig_params]
        partial_f = f"lambda {', '.join(remaining_args)}: _orig({', '.join(arg_substitute)})"
        # Evaluate in an explicit namespace, so the original function does not have
        # to be reachable as a global under orig_function.__name__
        return eval(partial_f, {"_orig": orig_function, "_fixed": dict(fixing_dict)})
    
    fixed_function = fixing_function(my_fun, {"a": 100, "e": 200})
    
    print(fixed_function(1, 2, 3, 4, 5, 6))
    #output: (1, 100, 2, 3, 4, 200, 5, 6)
    print(getfullargspec(fixed_function))
    #output: FullArgSpec(args=['x', 'b', 'c', 'd', 'f', 'g'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={})
    

    There are probably ways this can be improved.
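One possible eval-free improvement, sketched under the assumption that every parameter of the original function is an ordinary positional-or-keyword parameter (no *args, **kwargs or positional-only ones): build the reduced inspect.Signature and attach it to a plain wrapper as __signature__. inspect.signature() and inspect.getfullargspec() report a __signature__ attribute when one is present, so the wrapper looks like a function of the remaining parameters only. Whether a given SciPy version's introspection picks it up is worth checking before relying on it.

```python
import inspect


def my_fun(x, a, b, c, d, e, f, g):
    return x, a, b, c, d, e, f, g


def fixing_function(orig_function, fixing_dict):
    """Return a wrapper of orig_function with the parameters in fixing_dict fixed.

    Assumes all parameters of orig_function are plain positional-or-keyword ones.
    """
    orig_sig = inspect.signature(orig_function)
    # Keep only the free parameters, in their original order
    remaining = [p for name, p in orig_sig.parameters.items() if name not in fixing_dict]
    new_sig = orig_sig.replace(parameters=remaining)

    def new_function(*args, **kwargs):
        # Check the call against the reduced signature (TypeError on a bad call)
        bound = new_sig.bind(*args, **kwargs)
        # Call the original with free and fixed parameters, all by keyword
        return orig_function(**bound.arguments, **fixing_dict)

    # inspect.signature() and inspect.getfullargspec() report this attribute
    new_function.__signature__ = new_sig
    return new_function


fixed_function = fixing_function(my_fun, {"a": 100, "e": 200})
print(fixed_function(1, 2, 3, 4, 5, 6))  # (1, 100, 2, 3, 4, 200, 5, 6)
print(inspect.getfullargspec(fixed_function).args)  # ['x', 'b', 'c', 'd', 'f', 'g']
```

Compared with the eval version, this avoids generating source code, and the wrapper also accepts the free parameters by keyword, e.g. fixed_function(1, 2, 3, 4, f=5, g=6).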