I am working on a problem that requires me to balance an arbitrary binary search tree, with the criterion that the left and right subtrees at each level should have the same number of nodes, or differ by at most one node.
How can I approach this problem? So far I have transformed the tree into a linked list, and that's it. I'm fairly sure that's the first step, but I'm not certain. I have looked everywhere for resources, but the closest thing I could find was the Day-Stout-Warren algorithm, which balances based on height rather than on the number of nodes.
Here's a Python solution that runs in O(n) time and rebalances the tree in place, using O(h) auxiliary space, where h is the height of the original tree; the only auxiliary data structure is the stack required for the recursive functions.
It works using a generator function which iterates over the tree while the consumer is changing the tree, but we make local copies of the left and right subtrees before yielding each node, so the consumer can reassign those without breaking the generator. (Actually, only a local copy of right is really required, because the left subtree has already been traversed by the time the node is yielded, but I made local copies of both anyway.)
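To see the point about local copies in isolation, here is a minimal sketch. The names naive_iter, safe_iter, make_chain and visit_and_detach are hypothetical and exist only for this illustration, and the Node class is a stripped-down stand-in. The consumer clears each node's child pointers as soon as it receives the node, which is roughly what can happen during rebalancing.

```python
class Node:
    # Minimal stand-in node type, used only for this demonstration.
    def __init__(self, data, left=None, right=None):
        self.data, self.left, self.right = data, left, right

def naive_iter(node):
    # Reads node.right only after resuming from the yield, so it sees
    # whatever the consumer assigned in the meantime.
    if node is not None:
        yield from naive_iter(node.left)
        yield node
        yield from naive_iter(node.right)

def safe_iter(node):
    # Snapshots node.right before yielding, so later reassignment by the
    # consumer does not affect the traversal.
    if node is not None:
        right = node.right
        yield from safe_iter(node.left)
        yield node
        yield from safe_iter(right)

def make_chain():
    # Right-leaning chain 1 -> 2 -> 3.
    return Node(1, right=Node(2, right=Node(3)))

def visit_and_detach(it):
    # Consumer that mutates every node it receives.
    seen = []
    for node in it:
        seen.append(node.data)
        node.left = node.right = None
    return seen

naive_seen = visit_and_detach(naive_iter(make_chain()))
safe_seen = visit_and_detach(safe_iter(make_chain()))
print(naive_seen)  # [1]
print(safe_seen)   # [1, 2, 3]
```

The naive version loses nodes 2 and 3 because node 1's right pointer was cleared before the generator resumed and read it.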
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

    def __repr__(self):
        # display for debug/testing purposes
        def _r(n):
            return '*' if n is None else '(%s ← %r → %s)' % (_r(n.left), n.data, _r(n.right))
        return _r(self)

def balance(root):
    def _tree_iter(node):
        if node is not None:
            # save to local variables, could be reassigned while yielding
            left, right = node.left, node.right
            yield from _tree_iter(left)
            yield node
            yield from _tree_iter(right)

    def _helper(it, k):
        if k == 0:
            return None
        else:
            half_k = (k - 1) // 2
            left = _helper(it, half_k)
            node = next(it)
            right = _helper(it, k - half_k - 1)
            node.left = left
            node.right = right
            return node

    n = sum(1 for _ in _tree_iter(root))
    return _helper(_tree_iter(root), n)
Example:
>>> root = Node(4, left=Node(3, left=Node(1, right=Node(2))), right=Node(6, left=Node(5), right=Node(8, left=Node(7), right=Node(9))))
>>> root
(((* ← 1 → (* ← 2 → *)) ← 3 → *) ← 4 → ((* ← 5 → *) ← 6 → ((* ← 7 → *) ← 8 → (* ← 9 → *))))
>>> balance(root)
(((* ← 1 → *) ← 2 → (* ← 3 → (* ← 4 → *))) ← 5 → ((* ← 6 → *) ← 7 → (* ← 8 → (* ← 9 → *))))
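If you want to check the node-count criterion from the question mechanically, something like the following sketch could be used. The function names subtree_size_if_balanced and is_count_balanced are hypothetical helpers introduced here, and the Node class is a minimal stand-in so the snippet runs on its own. The two trees are the input and output from the example above, built by hand.

```python
class Node:
    # Minimal stand-in node type so this snippet is self-contained.
    def __init__(self, data, left=None, right=None):
        self.data, self.left, self.right = data, left, right

def subtree_size_if_balanced(node):
    # Returns the number of nodes in the subtree, or -1 if somewhere in it
    # the left and right node counts differ by more than one.
    if node is None:
        return 0
    left_size = subtree_size_if_balanced(node.left)
    if left_size < 0:
        return -1
    right_size = subtree_size_if_balanced(node.right)
    if right_size < 0 or abs(left_size - right_size) > 1:
        return -1
    return left_size + right_size + 1

def is_count_balanced(root):
    return subtree_size_if_balanced(root) >= 0

# The unbalanced input tree from the example.
original = Node(4,
                left=Node(3, left=Node(1, right=Node(2))),
                right=Node(6, left=Node(5),
                           right=Node(8, left=Node(7), right=Node(9))))

# The tree shown as the result of balance(root) in the example.
rebalanced = Node(5,
                  left=Node(2, left=Node(1), right=Node(3, right=Node(4))),
                  right=Node(7, left=Node(6), right=Node(8, right=Node(9))))

print(is_count_balanced(original))    # False
print(is_count_balanced(rebalanced))  # True
```

The check runs in O(n) time, since each node is visited once and the sizes are passed up the recursion rather than recounted.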