Conditional function chaining in Python


Imagine that there is a function g I want to implement by chaining sub-functions. This can be easily done by:

def f1(a):
    return a+1

def f2(a):
    return a*2

def f3(a):
    return a**3

g = lambda x: f1(f2(f3(x)))

However, now suppose that which sub-functions are chained together depends on conditions: specifically, user-specified options that are known in advance. One could of course do:

def g(a, cond1, cond2, cond3):

    res = a
    if cond1:
        res = f3(res)
    if cond2:
        res = f2(res)
    if cond3:
        res = f1(res)
    return res

However, instead of dynamically checking these static conditions each time the function is called, I assume that it's better to define the function g from its constituent functions in advance. Unfortunately, the following gives a RuntimeError: maximum recursion depth exceeded (a RecursionError on Python 3.5+) as soon as g is called:

g = lambda x: x
if cond1:
    g = lambda x: f3(g(x))
if cond2:
    g = lambda x: f2(g(x))
if cond3:
    g = lambda x: f1(g(x))

Is there a good way of doing this conditional chaining in Python? Please note that there can be N functions to chain, so separately defining all 2^N function combinations (8 in this example) is not an option.
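
The recursion comes from late binding: a lambda body looks up the name g when it is called, not when it is defined. After the last reassignment, g refers to the outermost lambda, so its inner g(x) call just calls itself again. A minimal sketch of one common workaround is to capture the current g as a default argument, which is evaluated at definition time. The parameter name `inner` and the hard-coded condition values are assumptions for illustration only:

```python
def f1(a):
    return a + 1

def f2(a):
    return a * 2

def f3(a):
    return a ** 3

# Example option values (assumed here; in practice they come from the user).
cond1, cond2, cond3 = True, False, True

g = lambda x: x
if cond1:
    # inner=g is evaluated now, so it keeps the *previous* g.
    g = lambda x, inner=g: f3(inner(x))
if cond2:
    g = lambda x, inner=g: f2(inner(x))
if cond3:
    g = lambda x, inner=g: f1(inner(x))

result = g(2)  # f1(f3(2)) = 2 ** 3 + 1
print(result)
```

With these values, g(2) evaluates f3 first and then f1, and no condition is checked at call time. One caveat of this sketch is that a caller could accidentally pass a second positional argument and override `inner`.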


Solution

  • I found one solution using closures: each enabled function is wrapped around the chain built so far. Have a look:

    def f1(x):
        return x + 1
    
    def f2(x):
        return x + 2
    
    def f3(x):
        return x ** 2
    
    
    conditions = [True, False, True]
    functions = [f1, f2, f3]
    
    
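    def apply_one(outer, inner):
        # Return a new function computing outer(inner(x)).
        def wrapped(x):
            return outer(inner(x))
        return wrapped
    
    
    def apply_conditions_and_functions(conditions, functions):
        def initial(x):
            return x
    
        function = initial
    
        # Pair conditions[i] with functions[i], then walk from the last
        # (innermost) function to the first (outermost).
        for cond, func in zip(reversed(conditions), reversed(functions)):
            if cond:
                function = apply_one(func, function)
        return function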
    
    
    g = apply_conditions_and_functions(conditions, functions)
    
    print(g(10)) # 101, because f1(f3(10)) = (10 ** 2) + 1 = 101
    

    The conditions are checked only once, when g is defined; they are not checked again each time g is called.
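
A possible alternative, sketched here under the assumption that conditions[i] is meant to enable functions[i], is to filter the function list once and fold over it with functools.reduce. The helper name `compose_selected` is made up for this example; the f1, f2, f3 definitions are the ones from the answer above:

```python
from functools import reduce

def f1(x):
    return x + 1

def f2(x):
    return x + 2

def f3(x):
    return x ** 2

def compose_selected(conditions, functions):
    """Return a function applying the enabled functions right-to-left.

    conditions[i] enables functions[i]; the last enabled function runs
    first, mirroring the nesting order f1(f2(f3(x))).
    """
    # The conditions are evaluated here, once, not on every call.
    selected = [f for cond, f in zip(conditions, functions) if cond]

    def composed(x):
        # Fold from the innermost (last) function outward.
        return reduce(lambda acc, f: f(acc), reversed(selected), x)

    return composed

g = compose_selected([True, False, True], [f1, f2, f3])
print(g(10))
```

Compared with nested closures, this keeps the call stack flat no matter how many functions are enabled, at the cost of a small loop on each call. If every condition is False, the result is the identity function.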