Tags: python, data-structures, binary-search-tree

How to return original root node with updated values after converting tree in function


I am attempting LeetCode problem 538. Convert BST to Greater Tree:

Given the root of a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the original BST is changed to the original key plus the sum of all keys greater than the original key in BST.
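Stated as a formula, and assuming distinct keys (the usual BST setting for this problem), each key $v$ in the tree $T$ becomes the sum of every key that is at least $v$. For example, key 3 in the tree built in the code below becomes $3+4+5+6+7+8 = 33$.

```latex
v' \;=\; v + \sum_{u \in T,\; u > v} u \;=\; \sum_{u \in T,\; u \ge v} u
```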

Example 1: [figure not reproduced]

My logic for solving this problem is as follows.

Example BST, represented by its in-order traversal: 0 1 2 3 4 5 6 7 8

Step 1

Get the in-order traversal of the current BST as a list -> [0, 1, 2, 3, 4, 5, 6, 7, 8], where each element represents a node.
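Step 1 can be sketched as a recursive helper that collects the node objects themselves (not just their values) in in-order sequence. This is a minimal sketch; the `Node` class and the `in_order_nodes` helper are hypothetical stand-ins for the `BinaryTree` class and `inOrder` function in the code below. Creating the list inside the function, rather than using a default argument such as `nodes=[]`, avoids one list being shared across separate calls.

```python
class Node:
    """Hypothetical minimal node class with the same `value` attribute as BinaryTree."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def in_order_nodes(root):
    """Return the nodes of the tree rooted at `root` in in-order (left, node, right)."""
    nodes = []  # a fresh list on every call, unlike a `nodes=[]` default argument

    def visit(node):
        if node is None:
            return
        visit(node.left)
        nodes.append(node)
        visit(node.right)

    visit(root)
    return nodes


# The example tree from the question; its in-order values are 0..8.
root = Node(4,
            Node(1, Node(0), Node(2, None, Node(3))),
            Node(6, Node(5), Node(7, None, Node(8))))
print([n.value for n in in_order_nodes(root)])  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```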

Step 2

  • The idea is to map each node to its greater sum

  • create a dictionary nodesToSums, which maps a node to its greater sum

  • create a sums list, where each element represents the greater sum of the corresponding node

  • for the example above, this gives the following:

  • inOrderNodes = [0, 1, 2, 3, 4, 5, 6, 7, 8], sums = [36, 36, 35, 33, 30, 26, 21, 15, 8], and a map that sends each node in inOrderNodes to its corresponding greater sum in sums. The two lists are matched by position, of course.

    Here I'm not sure whether I should do this by going through the list of nodes (inOrderNodes), through nodesToSums, or through the original root. I decided to perform a BFS on the root and, for each node, replace the current node value with its corresponding greater sum: node.val = nodesToSums[node]

  • at the very end return root
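The sums list in step 2 is a suffix sum (a running total taken from the right end) of the in-order values. A minimal sketch of that step using `itertools.accumulate` on the reversed values; the `Node` class and `greater_sums` function are hypothetical. The dictionary is keyed by the node objects themselves, which works because instances of a plain class hash by identity. The last lines touch on the open question in step 2: because the list holds references to the same node objects, writing back through the list is enough.

```python
from itertools import accumulate


class Node:
    """Hypothetical minimal node class with a `value` attribute."""

    def __init__(self, value):
        self.value = value


def greater_sums(nodes):
    """Map each node (given in in-order) to its value plus all larger values."""
    values = [node.value for node in nodes]
    # Accumulate from the right, then flip back into in-order positions.
    sums = list(accumulate(reversed(values)))[::-1]
    return dict(zip(nodes, sums))


nodes = [Node(v) for v in range(9)]  # stand-in for the in-order list from step 1
nodes_to_sums = greater_sums(nodes)
print([nodes_to_sums[n] for n in nodes])  # [36, 36, 35, 33, 30, 26, 21, 15, 8]

# The list holds references to the node objects themselves, so assigning through
# it changes those nodes wherever else they are referenced (e.g. in the tree).
for node in nodes:
    node.value = nodes_to_sums[node]
print([n.value for n in nodes])  # [36, 36, 35, 33, 30, 26, 21, 15, 8]
```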

I believe my time complexity is O(n), but I am not sure how to work out my space complexity. I do know that I am using O(n) for the inOrderNodes list, another O(n) for the nodesToSums map, and another O(n) for the queue in my BFS traversal. I say O(n) for each because each of them grows linearly with n, the number of nodes in the BST.
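Adding up the auxiliary space named in that paragraph, and also counting the recursion stack of the in-order traversal (whose depth is the tree height $h$), gives a linear total, since constant factors drop out. One caveat on the time side: `list.pop(0)` shifts every remaining element, so `collections.deque.popleft()` is the usual constant-time choice for a BFS queue.

```latex
S(n) = \underbrace{O(n)}_{\text{in-order list}}
     + \underbrace{O(n)}_{\text{nodesToSums}}
     + \underbrace{O(n)}_{\text{BFS queue}}
     + \underbrace{O(h)}_{\text{recursion stack}}
     = O(n), \qquad h \le n
```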

Okay, so I thought my logic was sound, but when I return root at the very end and print an in-order traversal of my tree, I get back the original root node with its original value. Furthermore, it's the only node I get back. Oops, what am I doing wrong? I am kinda lost at this point.
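The "original value" part of the symptom can be reproduced in isolation. The sketch below uses a copy of the question's `BinaryTree` class: assigning to `node.val` on an object that stores its key in `value` does not raise an error. It silently creates a second attribute, so anything that reads `value` still sees the old key.

```python
class BinaryTree:
    """Same shape as the question's class: the key lives in `value`."""

    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


node = BinaryTree(4)
node.val = 30  # `val` is a different attribute name from `value`

print(node.value)  # 4 (unchanged)
print(vars(node))  # {'value': 4, 'left': None, 'right': None, 'val': 30}
```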

Here is my code...

class BinaryTree:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def convertBST(root):
    if root is None:
        return root

    nodes = inOrder(root)
    nodesToSums = calculateSums(nodes)

    queue = [root]

    while len(queue) > 0:
        node = queue.pop(0)
        node.val = nodesToSums[node]

        if node.left is not None:
            queue.append(node.left)

        if node.right is not None:
            queue.append(node.right)

    return root


def calculateSums(nodes):
    nodesToSums = {}
    sums = [node.value for node in nodes]

    for i in range(len(sums) - 1, -1, -1):
        if i == len(sums) - 1:
            nodesToSums[nodes[i]] = sums[i]
        else:
            sums[i] = sums[i] + sums[i + 1]
            nodesToSums[nodes[i]] = sums[i]

    return nodesToSums


def inOrder(root, nodes=[]):
    if root is None:
        return []

    inOrder(root.left, nodes)
    nodes.append(root)
    inOrder(root.right, nodes)

    return nodes


def printInOrder(root):
    if root is None:
        return

    inOrder(root.left)
    print(root.value, end="")
    inOrder(root.right)


root = BinaryTree(4)
root.left = BinaryTree(1)
root.right = BinaryTree(6)
root.left.left = BinaryTree(0)
root.left.right = BinaryTree(2)
root.right.left = BinaryTree(5)
root.right.right = BinaryTree(7)
root.left.right.right = BinaryTree(3)
root.right.right.right = BinaryTree(8)

convertBST(root)

printInOrder(root)

I am not looking for a complete solution; there are plenty out there from people flexing by posting their own (I can go to LeetCode or Google for that). I am looking for an answer to the specific question I asked: why is only the root node, and just the root node, returned at the end? Please, a post with a different solution that is not a correction of mine won't help me, and I would rather not see one, as I am practicing.


Solution

  • Here are some of the issues:

    • The main problem is that your code uses both val and value as the node attribute. In the LeetCode problem the class is defined with a val attribute, but your class has value, and the rest of your code mixes val and value references. As a consequence, convertBST does not alter the existing attribute but adds another one.
    • printInOrder should call printInOrder recursively, not inOrder. This bug explains why your output has only one value.
    • printInOrder prints everything concatenated; better to use end=" " in the print call.

    With the first bullet point fixed, your code will run fine on LeetCode, and with the other two fixed, you'll also get the expected output in your local environment.

    However, I find the use of a dictionary a bit overkill here. The queue is overkill as well, as the order in which you visit the nodes at that point is no longer important; you might as well just iterate over your nodes list.

    The idea of the in-order traversal is great, but it would be more practical to traverse the tree in reverse in-order (right subtree, node, left subtree). Then you can accumulate the running sum and update each node at the moment you visit it, greatly simplifying the code.

    Here is a spoiler in case any other visitor would want to know how that would look (I use the LeetCode version of the class):

    class TreeNode:
        def __init__(self, val=0, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def convertBST(root):
        runningsum = 0
        for node in reversedInOrder(root):
            runningsum += node.val
            node.val = runningsum
        return root

    def reversedInOrder(root):
        if root:
            yield from reversedInOrder(root.right)
            yield root
            yield from reversedInOrder(root.left)
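For comparison with the spoiler, here is a sketch of the question's own approach with the three fixes from the list applied, plus the simplification suggested above (no dictionary, no queue: just iterate over the nodes list). It keeps the question's `BinaryTree` class and `value` attribute; the only other change is that `inOrder` no longer uses a mutable default argument. Treat it as one possible correction, not the definitive one.

```python
class BinaryTree:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def inOrder(root, nodes=None):
    # Avoid a shared mutable default: create the list on the outermost call only.
    if nodes is None:
        nodes = []
    if root is not None:
        inOrder(root.left, nodes)
        nodes.append(root)
        inOrder(root.right, nodes)
    return nodes


def convertBST(root):
    total = 0
    # Walk the in-order list from the largest key down, keeping a running sum,
    # and write it back through the same `value` attribute the class defines.
    for node in reversed(inOrder(root)):
        total += node.value
        node.value = total
    return root


def printInOrder(root):
    if root is None:
        return
    printInOrder(root.left)  # recurse with printInOrder, not inOrder
    print(root.value, end=" ")
    printInOrder(root.right)


root = BinaryTree(4)
root.left = BinaryTree(1)
root.right = BinaryTree(6)
root.left.left = BinaryTree(0)
root.left.right = BinaryTree(2)
root.right.left = BinaryTree(5)
root.right.right = BinaryTree(7)
root.left.right.right = BinaryTree(3)
root.right.right.right = BinaryTree(8)

convertBST(root)
printInOrder(root)  # prints: 36 36 35 33 30 26 21 15 8
```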