
Python nested currying


I was trying to solve a Codewars problem and got a bit stuck. I believe I should be using nested currying in Python.

Take the case of add. To constrain the problem even further, I just want nested add working on the right-hand side, i.e. an add function such that


print((add)(3)(add)(5)(4))

prints 12.

It should be possible to nest it as deeply as required. For example, I want

print((add)(add)(3)(4)(add)(5)(6))

to give 18.
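A minimal sketch of the behaviour described above. It assumes add is strictly binary and that the chain is read in prefix (Polish) notation: keep a stack of pending additions, open a new one whenever add itself arrives as an argument, and return a plain number as soon as the outermost addition has both operands. The helper names _feed and _continue are hypothetical. The stack is a tuple of tuples, so every intermediate function can be reused safely.

```python
def add(arg):
    """Entry point: the chain starts with one addition awaiting operands."""
    return _feed(((),), arg)


def _feed(stack, arg):
    """Consume one argument of a prefix-notation addition chain.

    `stack` is a tuple of pending additions; each entry is a tuple of the
    operands collected so far. Tuples are immutable, so no state is shared
    between the functions this returns.
    """
    if arg is add:
        # A nested `add` opens a new pending addition on top of the stack.
        return _continue(stack + ((),))
    value = arg
    while True:
        top = stack[-1] + (value,)
        rest = stack[:-1]
        if len(top) < 2:
            # Still waiting for this addition's second operand.
            return _continue(rest + (top,))
        value = top[0] + top[1]
        if not rest:
            # The outermost addition is complete: the expression is "done".
            return value
        # Feed the sum to the enclosing addition and keep reducing.
        stack = rest


def _continue(stack):
    """Return a function that waits for the next argument."""
    return lambda arg: _feed(stack, arg)


print((add)(3)(add)(5)(4))          # 12
print((add)(add)(3)(4)(add)(5)(6))  # 18
```

Because add has a fixed arity of two, the evaluator can tell when the last operand arrives, which is exactly the missing "done" signal.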

What I have done so far:

My initial attempt uses the following nested function:

def add_helper():
    current_sum = 0

    def inner(inp):
        if isinstance(inp, int):
            nonlocal current_sum
            current_sum += inp
            print(f"current_sum = {current_sum}")

        return inner

    return inner


add = add_helper()

However, this does not work. When I run print((add)(add)(3)(4)(add)(5)(6)), I get the following output instead:

current_sum = 3
current_sum = 7
current_sum = 12
current_sum = 18
<function add_helper.<locals>.inner at 0x...>

How should I change my function so that it simply returns 18, because it knows when it is "done"?
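One way to see how a curried function can "know" it is done: give it a fixed arity and count the arguments it has received. The curry helper below is a hypothetical sketch, not code from this thread. It reads the arity with inspect.signature and does not handle nested calls.

```python
import inspect


def curry(func):
    """Return a one-argument-at-a-time version of `func` (hypothetical helper)."""
    arity = len(inspect.signature(func).parameters)

    def collect(*args):
        # Once every positional parameter has a value, call the function.
        if len(args) == arity:
            return func(*args)
        # Otherwise return a function that waits for the next argument.
        return lambda arg: collect(*args, arg)

    return collect()


add2 = curry(lambda a, b: a + b)
print(add2(3)(4))  # 7
```

Since the collected arguments live in an immutable tuple, a partial such as add2(10) can be called repeatedly. Nesting needs one more ingredient: when an argument is itself an operator, the outer operation has to stay open until the inner one finishes, which is what a stack provides.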

Any help will be appreciated!

UPDATE

After looking at Bharel's comments, this is what I have so far:


def add_helper():
    val = 0
    ops_so_far = []
    def inner(inp):
        if isinstance(inp, int):
            nonlocal val
            val += inp
            return inner
        else:
            ops_so_far.append(("+", val))
            inp.set_ops_so_far(ops_so_far)
            return inp
    def set_ops_so_far(inp_list):
        nonlocal ops_so_far
        ops_so_far = inp_list

    def get_val():
        nonlocal val
        return val

    def get_ops_so_far():
        nonlocal ops_so_far
        return ops_so_far

    inner.get_ops_so_far = get_ops_so_far
    inner.set_ops_so_far = set_ops_so_far
    inner.get_val = get_val
    return inner


def mul_helper():
    val = 1
    ops_so_far = []
    def inner(inp):

        if isinstance(inp, int):
            nonlocal val
            val *= inp
            return inner
        else:
            ops_so_far.append(("*", val))
            inp.set_ops_so_far(ops_so_far)
            return inp

    def get_ops_so_far():
        nonlocal ops_so_far
        return ops_so_far

    def set_ops_so_far(inp_list):
        nonlocal ops_so_far
        ops_so_far = inp_list

    def get_val():
        nonlocal val
        return val

    inner.get_ops_so_far = get_ops_so_far
    inner.get_val = get_val
    inner.set_ops_so_far = set_ops_so_far

    return inner


add = add_helper()
mul = mul_helper()

and now when I run


res = (add)(add)(3)(4)(mul)(5)(6)
print(res.get_ops_so_far())
print(res.get_val())

I get

[('+', 0), ('+', 7)]
30

I am still not sure whether this is the right direction to follow. Is it?
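The update's direction can be pushed further by evaluating as the arguments arrive instead of only recording them. A sketch, assuming every operator is binary and the chain is prefix notation, so (add)(add)(3)(4)(mul)(5)(6) means (3 + 4) + (5 * 6). Each stack entry pairs an operator with the operands collected so far. The names _make_op, _feed and OPS are hypothetical; the arithmetic comes from the standard operator module.

```python
import operator

# Maps each public operator function (add, mul) to its binary operation.
OPS = {}


def _make_op(binary):
    def op(arg):
        # Starting a chain with this operator opens one pending application.
        return _feed(((op, ()),), arg)

    OPS[op] = binary
    return op


def _feed(stack, arg):
    """Consume one argument; return a number once the chain is complete."""
    if arg in OPS:
        # A nested operator: push a new pending application.
        return lambda nxt: _feed(stack + ((arg, ()),), nxt)
    value = arg
    while True:
        (op, operands), rest = stack[-1], stack[:-1]
        operands += (value,)
        if len(operands) < 2:
            # Wait for this operator's second operand.
            return lambda nxt: _feed(rest + ((op, operands),), nxt)
        value = OPS[op](*operands)
        if not rest:
            # The outermost operator has been applied: done.
            return value
        # Pass the result up to the enclosing operator.
        stack = rest


add = _make_op(operator.add)
mul = _make_op(operator.mul)

print((add)(add)(3)(4)(mul)(5)(6))  # 37
```

Keeping the state in immutable tuples plays the same role as copying it: each returned function owns its own snapshot of the stack.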


Solution

  • For anyone looking for this in the future, this is how I solved it:

    
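    from copy import deepcopy


    def start(arg):
        """Entry point: begin evaluating a prefix-notation call chain."""

        def start_evaluation(_arg, eval_stack, variables):
            # Copy the state so each returned `inner` owns its own stack and
            # variables; a partial chain can then be reused without side effects.
            new_eval_stack = deepcopy(eval_stack)
            new_variables = deepcopy(variables)
            to_ret = evaluate_stack(_arg, new_eval_stack, new_variables)

            # A `return_` has been completed: the whole expression is done.
            if to_ret is not None:
                return to_ret

            # Otherwise keep accepting arguments one at a time.
            def inner(inner_arg):
                return start_evaluation(inner_arg, new_eval_stack, new_variables)

            return inner

        return start_evaluation(arg, [], {})


    # Binary operators; names bound with `let` are looked up in `variables`.
    add = lambda a, b, variables: variables.get(a, a) + variables.get(b, b)
    sub = lambda a, b, variables: variables.get(a, a) - variables.get(b, b)
    mul = lambda a, b, variables: variables.get(a, a) * variables.get(b, b)
    div = lambda a, b, variables: variables.get(a, a) // variables.get(b, b)


    def let(name, val, variables):
        # Bind `name` to `val`; returning None tells the evaluator to wait
        # for the next argument.
        variables[name] = val


    def return_(val, variables):
        return variables.get(val, val)


    def evaluate_stack(_arg, eval_stack, variables):
        if callable(_arg):
            # An operator: push it with its arity and an empty argument list.
            req_args = 1 if _arg is return_ else 2
            eval_stack.append((_arg, req_args, []))
        else:
            # A value: feed it to the innermost pending operator, and keep
            # reducing while operators have all of their arguments.
            while True:
                func_to_eval, req_args, args_so_far = eval_stack[-1]
                args_so_far.append(_arg)
                if len(args_so_far) < req_args:
                    break
                eval_stack.pop()
                _arg = func_to_eval(*args_so_far, variables)
                if func_to_eval is return_:
                    return _arg
                if _arg is None:  # `let` produced no value
                    break
        return None


    # Usage (prefix notation, with the expression wrapped in `return_`):
    # start(return_)(add)(1)(2)                    -> 3
    # start(let)("x")(5)(return_)(mul)("x")("x")   -> 25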
    
    

    This passes all the test cases.