
Why does "yield" work in a recursive function?


While trying to implement a permutation function, I found this code on the web.

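def permutations(seq):
    if len(seq) == 1:
        # Base case: a single element has exactly one permutation.
        yield seq
    else:
        for i in range(len(seq)):
            # Generator over the permutations of seq without seq[i].
            perms = permutations(seq[:i] + seq[i+1:])
            for p in perms:
                # Prefix each sub-permutation with seq[i].
                yield [seq[i], *p]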

I tried to understand the code but failed, because I don't see how yield can be used recursively. I know that each time I call next(generator), it runs the code in the body and stops at a yield statement. But how can it reach the base case if I only call next once? It obviously has to recurse several times before it reaches the base case (len(seq) == 1). Based on my understanding, the yield at the bottom of the recursion should be a return.
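One way to see what a single next() call actually does is to instrument the function. The sketch below uses a hypothetical `traced_permutations`, which follows the same recursion as the code above but takes two extra parameters added only for this illustration: a `log` list and a `depth` used for indentation. It records when each generator's body starts running and when it yields.

```python
def traced_permutations(seq, log, depth=0):
    # This line runs only when the generator is first advanced,
    # not when traced_permutations(...) is called.
    log.append(f"{'  ' * depth}start {seq!r}")
    if len(seq) == 1:
        log.append(f"{'  ' * depth}yield {seq!r}")
        yield seq
    else:
        for i in range(len(seq)):
            perms = traced_permutations(seq[:i] + seq[i+1:], log, depth + 1)
            for p in perms:
                result = [seq[i], *p]
                log.append(f"{'  ' * depth}yield {result!r}")
                yield result


log = []
gen = traced_permutations([1, 2, 3], log)
print(log)          # [] : creating the generator ran none of its body
first = next(gen)   # a single next() call
print(first)        # [1, 2, 3]
print("\n".join(log))
# start [1, 2, 3]
#   start [2, 3]
#     start [3]
#     yield [3]
#   yield [2, 3]
# yield [1, 2, 3]
```

The log shows that one next() on the outer generator makes the for loop at each level call next() on the generator one level below, all the way down to the base case, and the first value is then passed back up through every level.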


Solution

  • The reason this works is as follows: in the loop body of the recursive branch of the condition, you first perform:

                perms = permutations(seq[:i] + seq[i+1:])
    

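    This creates a new generator object, local to the current recursion step. Creating it does not run any of its body yet; that only happens once something starts asking it for values.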

    The next two lines:

                for p in perms:
                    yield [seq[i], *p]
    

    iterate over all the values produced by that generator object and yield a permutation built from each of them.

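    The yield statement does not "penetrate" the recursion: a value yielded inside the recursive call goes only to whatever is iterating over that inner generator, which here is the for loop one level up. Your code collects the values yielded by the recursive call and then yields its own results built from them. So a single next() on the outermost generator makes each level's for loop call next() on the level below until the base case is reached, and only then does the first value travel back up.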

    There is in fact a lot of yielding going on here: each of the n! permutations of a length-n seq is re-yielded once at each of the n levels of recursion, so there are n·n! yields in total. That is to be expected, given the size of the output.
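Both points in this answer can be checked with a short experiment. The sketch below uses two hypothetical helpers written only for illustration: `counted_permutations`, the same recursion with a yield counter added, and `broken_permutations`, a variant that creates the recursive generator but never iterates over it.

```python
from math import factorial


def counted_permutations(seq, counter):
    # Same recursion as permutations(), plus a tally of every yield at every
    # level. `counter` is a one-element list used as a mutable counter.
    if len(seq) == 1:
        counter[0] += 1
        yield seq
    else:
        for i in range(len(seq)):
            # Values yielded by the inner generator arrive at this for loop,
            # one level up, not at the outermost caller.
            for p in counted_permutations(seq[:i] + seq[i+1:], counter):
                counter[0] += 1
                yield [seq[i], *p]


def broken_permutations(seq):
    # Creates the inner generators but never iterates over them, so none of
    # their values ever reach the caller.
    if len(seq) == 1:
        yield seq
    else:
        for i in range(len(seq)):
            broken_permutations(seq[:i] + seq[i+1:])


print(list(broken_permutations([1, 2, 3])))  # []

for n in range(1, 7):
    counter = [0]
    perms = list(counted_permutations(list(range(n)), counter))
    # Columns: n, number of permutations, total yields, n * n!
    print(n, len(perms), counter[0], n * factorial(n))
```

The broken variant produces nothing for any input longer than one element, because the inner yields only reach a generator object that nobody iterates. For the counted variant, the third and fourth columns are expected to match (for example 18 for n = 3 and 96 for n = 4).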