data-structures, binary-tree, binary-search-tree, graph-theory

Binary search tree problem: transform a BST into a BST where the numbers of nodes in the left and right subtrees differ by at most 1


I am working on a problem that requires me to balance an arbitrary binary search tree, with the criterion that, at every node, the left and right subtrees contain the same number of nodes or differ by at most 1 node.

How can I approach this problem? So far I have transformed the tree into a linked list, and that's it. I'm fairly sure that is the first step, but I'm not certain. I have looked everywhere for resources, but the closest thing I could find was the Day-Stout-Warren algorithm, which balances based on height rather than on the number of nodes.


Solution

  • Here's a Python solution which works in O(n) time, in-place with O(h) auxiliary space where h is the height of the tree; the only auxiliary data structure is the stack required for the recursive functions.

    It works using a generator function which iterates over the tree in order while the consumer is changing the tree. The generator makes local copies of the left and right subtrees before yielding a node, so the consumer can reassign those without breaking the generator. (Actually only a local copy of right is really required, because node.left is read before the node is yielded, when the consumer cannot have changed it yet, whereas node.right is read afterwards. I made local copies of both anyway.)

    class Node:
        def __init__(self, data, left=None, right=None):
            self.data = data
            self.left = left
            self.right = right
        def __repr__(self):
            # display for debug/testing purposes
            def _r(n):
                return '*' if n is None else '(%s ← %r → %s)' % (_r(n.left), n.data, _r(n.right))
            return _r(self)
    
    def balance(root):
        def _tree_iter(node):
            if node is not None:
                # save to local variables, could be reassigned while yielding
                left, right = node.left, node.right
                yield from _tree_iter(left)
                yield node
                yield from _tree_iter(right)
        
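        def _helper(it, k):
            # build a size-balanced subtree from the next k nodes of the iterator
            if k == 0:
                return None
            else:
                # one node is the root; the left subtree gets half of the rest, rounded down
                half_k = (k - 1) // 2
                left = _helper(it, half_k)
                node = next(it)
                right = _helper(it, k - half_k - 1)
                # relink only after both subtrees are built
                node.left = left
                node.right = right
                return node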
        
        n = sum(1 for _ in _tree_iter(root))
        return _helper(_tree_iter(root), n)
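
The split arithmetic used by `_helper` can be looked at on its own to see why the criterion holds. The sketch below is only an illustration; `split_sizes` is a hypothetical name that does not appear in the code above.

```python
def split_sizes(k):
    """Return (left_size, right_size) for a subtree of k nodes, k >= 1.

    Mirrors the arithmetic in _helper: one node becomes the root and the
    remaining k - 1 nodes are divided as evenly as possible.
    """
    half_k = (k - 1) // 2
    return half_k, k - half_k - 1


# Show the split for a few small subtree sizes.
for k in range(1, 8):
    left, right = split_sizes(k)
    print(k, left, right)
```

The two sizes always add up to k - 1 and never differ by more than 1, so every node built by `_helper` meets the criterion, and the same holds recursively for its subtrees.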
    

    Example:

    >>> root = Node(4, left=Node(3, left=Node(1, right=Node(2))), right=Node(6, left=Node(5), right=Node(8, left=Node(7), right=Node(9))))
    >>> root
    (((* ← 1 → (* ← 2 → *)) ← 3 → *) ← 4 → ((* ← 5 → *) ← 6 → ((* ← 7 → *) ← 8 → (* ← 9 → *))))
    >>> balance(root)
    (((* ← 1 → *) ← 2 → (* ← 3 → (* ← 4 → *))) ← 5 → ((* ← 6 → *) ← 7 → (* ← 8 → (* ← 9 → *))))
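
To confirm a result like the one above without reading the repr by eye, a small checker can help. This is a minimal sketch, not part of the original answer: `size_if_balanced` and `is_size_balanced_bst` are hypothetical names, the `Node` class is a stand-in with the same fields as the one in the answer, and keys are assumed to be distinct.

```python
class Node:
    # Minimal stand-in with the same fields as the answer's Node class.
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def size_if_balanced(node, lo=None, hi=None):
    """Return the number of nodes under `node`, or -1 if any subtree
    violates BST order or the size-balance criterion.

    Assumes distinct keys; `lo` and `hi` are exclusive bounds inherited
    from the ancestors.
    """
    if node is None:
        return 0
    # BST order: the key must lie strictly between the inherited bounds.
    if lo is not None and node.data <= lo:
        return -1
    if hi is not None and node.data >= hi:
        return -1
    left = size_if_balanced(node.left, lo, node.data)
    if left < 0:
        return -1
    right = size_if_balanced(node.right, node.data, hi)
    # Size balance: subtree node counts may differ by at most 1.
    if right < 0 or abs(left - right) > 1:
        return -1
    return left + right + 1


def is_size_balanced_bst(root):
    return size_if_balanced(root) >= 0
```

Calling `is_size_balanced_bst(balance(root))` with the answer's `balance` function should return True for any BST with distinct keys. Like the answer's code, the checker is recursive, so a very deep tree may hit Python's recursion limit.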