
Ternary tree paths


I'm solving the following question:

Given a ternary tree (each node of the tree has at most three children), find all root-to-leaf paths.

Example:

[Image: the example ternary tree, rooted at 1, whose root-to-leaf paths are 1 → 2 → 3, 1 → 4, and 1 → 6]

My code is as follows:

from __future__ import annotations

import itertools
from collections import deque
from collections.abc import Iterable


class TernaryNode:
    def __init__(self, val: int) -> None:
        self.children: list[TernaryNode] = []
        self.val = val

    def __repr__(self) -> str:
        return str(self.val)


def ternary_tree_paths(root: TernaryNode) -> Iterable[Iterable[int]]:
    def _visit(node: TernaryNode) -> Iterable[deque[int]]:
        if not node:
            return []
        if not node.children:
            queue = deque()
            queue.append(node.val)
            return [queue]
        # **
        paths = itertools.chain.from_iterable(map(lambda ch: _visit(ch), node.children))
        for p in paths:
            p.appendleft(node.val)

        return paths

    return _visit(root)

As written, the code above returns an empty result, whereas the desired result is [deque([1, 2, 3]), deque([1, 4]), deque([1, 6])].

Note the line marked **: if I rewrite it as paths = [p for ch in node.children for p in _visit(ch)], it works as expected. My guess is that the problem comes from from_iterable being evaluated lazily, but shouldn't iterating over the items force it to be evaluated?
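A minimal, self-contained illustration of what is going on. The nested lists of deques below are made-up stand-ins for what _visit returns for two leaf children. Iterating the chain object does evaluate it, but a chain is a one-shot iterator, so a second pass over it finds nothing:

    import itertools
    from collections import deque

    # Hypothetical child results, shaped like what _visit returns for leaves.
    child_results = [[deque([2])], [deque([4])]]

    paths = itertools.chain.from_iterable(child_results)
    for p in paths:          # this loop runs the iterator and mutates each deque...
        p.appendleft(1)

    leftover = list(paths)   # ...but afterwards the iterator is exhausted
    print(leftover)          # []

    # Materializing into a list once gives a container that can be reused.
    paths = list(itertools.chain.from_iterable([[deque([2])], [deque([4])]]))
    for p in paths:
        p.appendleft(1)
    print(paths)             # [deque([1, 2]), deque([1, 4])]

The mutation in the first loop did happen (child_results now holds deque([1, 2]) and deque([1, 4])); only the returned iterator is empty.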


Solution

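  • The for loop that calls appendleft consumes the chain iterator. Iterating does force evaluation, as you expected, but a chain object is a one-shot iterator: once the loop has walked through it, it is exhausted, so the paths object you return yields nothing. The same happens in every recursive call for an internal node, so those intermediate results are lost as well.

    You need to consume the iterator exactly once and keep what it produces, for example by collecting the updated paths into a list as you go: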

    def ternary_tree_paths(root: TernaryNode) -> Iterable[Iterable[int]]:
        def _visit(node: TernaryNode) -> Iterable[deque[int]]:
            if not node:
                return []
            if not node.children:
                # A leaf starts a new path containing only its own value.
                return [deque([node.val])]
            paths = itertools.chain.from_iterable(map(_visit, node.children))
            retval = []  # materialized results, returned to the caller
            for p in paths:  # consumes the chain iterator exactly once
                p.appendleft(node.val)  # prepend this node to each child path
                retval.append(p)
    
            return retval
    
        return _visit(root)
    

    For the example in the question, this yields:

    [deque([1, 2, 3]), deque([1, 4]), deque([1, 6])]
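An alternative sketch, not part of the answer above: a recursive generator using yield from can avoid both the exhausted iterator and the in-place mutation, by passing the path prefix down the tree and yielding each complete path as a tuple. The TernaryNode class is repeated so the snippet runs on its own, and the example tree is rebuilt from the expected output in the question:

    from __future__ import annotations

    from collections.abc import Iterator


    class TernaryNode:
        def __init__(self, val: int) -> None:
            self.children: list[TernaryNode] = []
            self.val = val


    def ternary_tree_paths(root: TernaryNode | None) -> Iterator[tuple[int, ...]]:
        """Yield every root-to-leaf path as a tuple of node values."""
        def _visit(node: TernaryNode, prefix: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
            path = prefix + (node.val,)  # new tuple, so sibling paths never share state
            if not node.children:        # leaf: the path is complete
                yield path
                return
            for child in node.children:
                yield from _visit(child, path)

        if root is not None:
            yield from _visit(root, ())


    # Example tree from the question: 1 -> (2 -> 3), 4, 6
    root = TernaryNode(1)
    two, four, six = TernaryNode(2), TernaryNode(4), TernaryNode(6)
    two.children.append(TernaryNode(3))
    root.children.extend([two, four, six])

    print(list(ternary_tree_paths(root)))  # [(1, 2, 3), (1, 4), (1, 6)]

Because the result is a generator, the caller decides when (and whether) to materialize it with list(...); just as with chain, it can only be iterated once.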