
Recursive solution to flatten binary tree to linked list


Here is the problem description: Flatten Binary Tree to Linked List. The given starter code is:

# class TreeNode(object):
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution(object):
    def flatten(self, root):
        """
        :type root: TreeNode
        :rtype: None Do not return anything, modify root in-place instead.
        """

The solution is:

class Solution:

    def flattenTree(self, node):

        # Handle the null scenario
        if not node:
            return None

        # For a leaf node, we simply return the
        # node as is.
        if not node.left and not node.right:
            return node

        # Recursively flatten the left subtree
        leftTail = self.flattenTree(node.left)

        # Recursively flatten the right subtree
        rightTail = self.flattenTree(node.right)

        # If there was a left subtree, we shuffle the connections
        # around so that there is nothing on the left side
        # anymore.
        if leftTail:
            leftTail.right = node.right
            node.right = node.left
            node.left = None

        # We need to return the "rightmost" node after we are
        # done wiring the new connections.
        return rightTail if rightTail else leftTail

    def flatten(self, root: TreeNode) -> None:
        """
        Do not return anything, modify root in-place instead.
        """

        self.flattenTree(root)
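To try the solution outside LeetCode, here is a minimal, self-contained harness. It is a sketch under a few assumptions: `TreeNode` is re-declared with the same fields as the commented-out LeetCode definition, and `build_tree`, `to_list` and `flatten_list` are hypothetical helper names (not part of the problem) that convert to and from LeetCode's level-order list notation. The `Solution` logic is the one above, with shorter comments.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def flattenTree(self, node):
        # Returns the tail (last node) of the flattened subtree rooted at node.
        if not node:
            return None
        if not node.left and not node.right:
            return node
        leftTail = self.flattenTree(node.left)
        rightTail = self.flattenTree(node.right)
        if leftTail:
            # Splice the flattened left list between node and its right list.
            leftTail.right = node.right
            node.right = node.left
            node.left = None
        return rightTail if rightTail else leftTail

    def flatten(self, root: Optional[TreeNode]) -> None:
        self.flattenTree(root)


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a LeetCode-style level-order list."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: level-order serialization, trailing Nones dropped."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


def flatten_list(values: List[Optional[int]]) -> List[Optional[int]]:
    """Hypothetical helper: build, flatten in place, serialize."""
    root = build_tree(values)
    Solution().flatten(root)
    return to_list(root)


print(flatten_list([1, 2, 3]))                 # [1, None, 2, None, 3]
print(flatten_list([1, 2, 5, 3, 4, None, 6]))  # [1, None, 2, None, 3, None, 4, None, 5, None, 6]
```

The second input is the tree 1(2(3, 4), 5(-, 6)). Its pre-order traversal is 1, 2, 3, 4, 5, 6, which is exactly the chain of right pointers left behind once every left pointer has been cleared.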

I don't understand this block of code:

        if leftTail:
            leftTail.right = node.right  # step 1
            node.right = node.left       # step 2
            node.left = None

For example, if the input tree is [1, 2, 3], then after step 1 the subtree rooted at leftTail (node 2) is [2, null, 3]. My naive expectation was that after step 2 the tree would become [1, null, 3], but to my surprise it becomes [1, null, 2, null, 3].


Solution

  • Take your example tree [1, 2, 3]:

      1 (node)
     / \
    2   3
    

    Let's check what each step does:

    if leftTail:
        leftTail.right = node.right  # step 1
        node.right = node.left       # step 2
        node.left = None             # step 3
    

    Step 1:

      1 (node)
     / \
    2   3
     \
      3 (the same node 3 that is still node 1's right child)
    

    Step 2:

      1 (node)
     / \
    2   2 (the same node 2 as the left child, not a copy)
     \   \
      3   3
    

    Step 3:

      1 (node)
       \
        2
         \
          3
    

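    So [1, null, 2, null, 3] is achieved. The key point is that these assignments copy references, not values: step 1 hangs node 3 off node 2, and step 2 makes node 1's right pointer refer to node 2 itself, so node 3 comes along with it.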
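The three diagrams can be reproduced in code. This is a small sketch that applies the same three assignments by hand to the [1, 2, 3] tree and serializes the tree after each one. `to_list` is a hypothetical helper that writes LeetCode-style level-order lists; it lists a node once for every pointer that reaches it, which is why node 3 (and later node 2) shows up twice. The `is` checks confirm that no copies are made at any point.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def to_list(root):
    """Hypothetical helper: level-order serialization, trailing Nones dropped.
    A node reachable through two pointers is listed once per pointer."""
    result, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


# The example tree [1, 2, 3]; node 2 is a leaf, so flattenTree returns it as leftTail.
n2, n3 = TreeNode(2), TreeNode(3)
node = TreeNode(1, n2, n3)
leftTail = n2

leftTail.right = node.right  # step 1: node 2 now also points at node 3
step1 = to_list(node)
assert leftTail.right is node.right  # one shared node 3, not two copies

node.right = node.left       # step 2: node 1's right now points at node 2 itself
step2 = to_list(node)
assert node.right is node.left is n2

node.left = None             # step 3: detach the left pointer
step3 = to_list(node)

print(step1)  # [1, 2, 3, None, 3]
print(step2)  # [1, 2, 2, None, 3, None, 3]
print(step3)  # [1, None, 2, None, 3]
```

The printed lists match the Step 1, Step 2 and Step 3 diagrams respectively.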