Python generator using `yield` not producing the desired result


I am trying to find all the paths in a graph. I found this amazing function that I reproduce here:

def paths(graph, v):
    """Generate the maximal cycle-free paths in graph starting at v.
    graph must be a mapping from vertices to collections of
    neighbouring vertices.

    >>> g = {1: [2, 3], 2: [3, 4], 3: [1], 4: []}
    >>> sorted(paths(g, 1))
    [[1, 2, 3], [1, 2, 4], [1, 3]]
    >>> sorted(paths(g, 3))
    [[3, 1, 2, 4]]

    """
    path = [v]                  # path traversed so far
    seen = {v}                  # set of vertices in path
    def search():
        dead_end = True
        for neighbour in graph[path[-1]]:
            if neighbour not in seen:
                dead_end = False
                seen.add(neighbour)
                path.append(neighbour)
                yield from search()
                path.pop()
                seen.remove(neighbour)
        if dead_end:
            yield list(path)
    yield from search()

However, as the docstring indicates, this function yields only completed paths, i.e. those that hit a dead end. I would like to change it to also yield incomplete paths, so that `sorted(paths(g, 1))` would return `[[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3]]`.

I tried adding the line `if not dead_end: yield list(path)` just before the line that says `path.pop()`. But that ends up yielding some paths twice and never yields the single-node path. The result I got was `[[1, 2, 3], [1, 2, 3], [1, 2, 4], [1, 2, 4], [1, 2], [1, 3], [1, 3]]`, which is not what I want!

Is it possible to modify this code to yield "not completed" paths? Could you advise me on how to go about it?


Solution

  • You're almost there! First, you'll need to yield your base case, the single-node path:

    yield list(path)
    

    Do this before you start iterating: by the time `search()` reaches its first `yield` statement, a neighbour has already been appended, so `[v]` on its own would never be produced. Yielding a copy with `list(path)` matters because `path` keeps being mutated as the search proceeds.

    Second, your duplicates come from the second `yield` statement, the one guarded by `dead_end`. Since you are now yielding on every iteration, you can remove it completely. Additionally, once that statement is gone nothing reads `dead_end` any more: passing the `if neighbour not in seen:` check already tells us we haven't reached a dead end, so the flag is redundant and can be removed too.

    So in summary:

    def paths(graph, v):
        """Generate every cycle-free path in graph starting at v,
        including paths that could still be extended.
        graph must be a mapping from vertices to collections of
        neighbouring vertices.
    
        >>> g = {1: [2, 3], 2: [3, 4], 3: [1], 4: []}
        >>> sorted(paths(g, 1))
        [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3]]
        >>> sorted(paths(g, 3))
        [[3], [3, 1], [3, 1, 2], [3, 1, 2, 4]]
    
        """
        path = [v]                  # path traversed so far
        seen = {v}                  # set of vertices in path
        yield list(path)            # base case: a copy of the single-vertex path
    
        def search():
            for neighbour in graph[path[-1]]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    path.append(neighbour)
                    yield from search()
    
                    yield list(path)
    
                    path.pop()
                    seen.remove(neighbour)
        yield from search()
    
    g = {1: [2, 3], 2: [3, 4], 3: [1], 4: []}
    
    print(sorted(paths(g, 1)))
    print(sorted(paths(g, 3)))
    

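    Also, `sorted()` could be traded out for `list()` if the order of the paths doesn't matter to you. Be aware that `list()` returns them in generation order, which for `paths(g, 1)` is `[[1], [1, 2, 3], [1, 2, 4], [1, 2], [1, 3]]`, because each path is yielded only after the recursive call on it has returned.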
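
If you would rather have each path appear before its extensions without calling `sorted()`, one option is to yield at the top of `search()` instead of after the recursive call. The following is only a sketch of that idea, not the original author's function: the name `prefixes` is made up for illustration, and the output matches the sorted order here only because each neighbour list in `g` happens to be in ascending order.

```python
def prefixes(graph, v):
    """Yield every cycle-free path in graph starting at v, shortest first
    along each branch.

    Hypothetical variant for illustration: a path is yielded as soon as it
    is reached (pre-order), before any of its extensions are explored.
    """
    path = [v]      # path traversed so far
    seen = {v}      # set of vertices in path

    def search():
        # Yield a copy: path itself keeps changing as the search continues.
        yield list(path)
        for neighbour in graph[path[-1]]:
            if neighbour not in seen:
                seen.add(neighbour)
                path.append(neighbour)
                yield from search()
                path.pop()
                seen.remove(neighbour)

    yield from search()


g = {1: [2, 3], 2: [3, 4], 3: [1], 4: []}
print(list(prefixes(g, 1)))   # [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3]]
```

Because the yield now sits at the start of `search()`, it also covers the single-node base case, so no separate yield is needed before the search starts. Yielding `list(path)` rather than `path` keeps earlier results intact when the generator is consumed lazily, for example with `next()`.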