
Split Binary Tree code does not work in some cases


I am trying to solve the Split Binary Tree problem in Python. Here is the problem statement:

Write a function that takes in a Binary Tree with at least one node and checks if that Binary Tree can be split into two Binary Trees of equal sum by removing a single edge. If this split is possible, return the new sum of each Binary Tree, otherwise return 0. You don't need to return the edge that was removed.
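A brute-force sketch that follows the problem statement literally may help pin down the definition: it removes each edge in turn, sums the two resulting trees, and puts the edge back. The names tree_sum and brute_force_split are hypothetical. It runs in quadratic time, so it is only meant as a reference when testing a faster solution.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def tree_sum(node):
    """Sum of all values in the subtree rooted at node (0 for an empty subtree)."""
    if node is None:
        return 0
    return node.value + tree_sum(node.left) + tree_sum(node.right)


def brute_force_split(tree):
    """Hypothetical reference checker: try removing every edge (parent, child)."""
    stack = [tree]
    while stack:
        parent = stack.pop()
        for side in ("left", "right"):
            child = getattr(parent, side)
            if child is None:
                continue
            setattr(parent, side, None)   # temporarily remove the edge
            rest = tree_sum(tree)         # sum of the part that still contains the root
            setattr(parent, side, child)  # restore the edge
            if rest == tree_sum(child):
                return rest
            stack.append(child)
    return 0
```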

The test code looks something like this:


import program
import unittest


class TestProgram(unittest.TestCase):
    def test_case_1(self):
        tree = program.BinaryTree(2)
        tree.left = program.BinaryTree(4)
        tree.left.left = program.BinaryTree(4)
        tree.left.right = program.BinaryTree(6)
        tree.right = program.BinaryTree(10)
        tree.right.left = program.BinaryTree(3)
        tree.right.right = program.BinaryTree(3)
        expected = 16
        actual = program.splitBinaryTree(tree)
        self.assertEqual(actual, expected)

And here is my code for splitBinaryTree:

# This is an input class. Do not edit.
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def splitBinaryTree(tree, balancesum=0):
    # Write your code here.
    if tree is None:
        return 0
    fullsum = icalculatesum(tree)
    print('fullsum is', fullsum, 'balancesum is', balancesum)
    if fullsum == balancesum:
        return fullsum
    leftsum = icalculatesum(tree.left)
    rightsum = icalculatesum(tree.right)
    if leftsum+tree.value == rightsum+balancesum:
        return fullsum/2
    if rightsum+tree.value == leftsum+balancesum:
        return fullsum/2
    if leftsum+tree.value+balancesum == rightsum:
        return fullsum/2
    if rightsum+tree.value+balancesum == leftsum:
        return fullsum/2
    lefty = splitBinaryTree(tree.left, fullsum-rightsum)
    righty = splitBinaryTree(tree.right, fullsum-leftsum)

    if lefty != 0 or righty !=0:
        return fullsum/2
    return 0


def icalculatesum(node, sumsofar=0):
    if node == None:
        return sumsofar
    sumsofar += node.value
    sumsofar = icalculatesum(node.left, sumsofar)
    sumsofar = icalculatesum(node.right, sumsofar)
    return sumsofar

My code fails on the following test case: the expected answer is 70, but mine returns 0.

{
  "nodes": [
    {"id": "1", "left": "9", "right": "20", "value": 1},
    {"id": "9", "left": "5", "right": "2", "value": 9},
    {"id": "20", "left": "30", "right": "10", "value": 20},
    {"id": "30", "left": null, "right": null, "value": 30},
    {"id": "10", "left": "35", "right": "25", "value": 10},
    {"id": "35", "left": null, "right": null, "value": 35},
    {"id": "25", "left": null, "right": null, "value": 25},
    {"id": "5", "left": null, "right": null, "value": 5},
    {"id": "2", "left": "3", "right": null, "value": 2},
    {"id": "3", "left": null, "right": null, "value": 3}
  ],
  "root": "1"
}
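To reproduce this case outside the platform, the JSON can be turned into BinaryTree objects. The sketch below assumes the format shown above (unique string ids, child ids or null in "left"/"right"); build_tree_from_json and tree_sum are hypothetical helpers, not part of the platform's harness.

```python
import json


class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build_tree_from_json(data):
    """Hypothetical helper: build nodes from {"nodes": [...], "root": id}."""
    # First create one node per id, then wire up the child links.
    nodes = {item["id"]: BinaryTree(item["value"]) for item in data["nodes"]}
    for item in data["nodes"]:
        node = nodes[item["id"]]
        if item["left"] is not None:
            node.left = nodes[item["left"]]
        if item["right"] is not None:
            node.right = nodes[item["right"]]
    return nodes[data["root"]]


def tree_sum(node):
    """Sum of all values in the subtree rooted at node (0 for an empty subtree)."""
    if node is None:
        return 0
    return node.value + tree_sum(node.left) + tree_sum(node.right)


FAILING_CASE = """
{"nodes": [
  {"id": "1", "left": "9", "right": "20", "value": 1},
  {"id": "9", "left": "5", "right": "2", "value": 9},
  {"id": "20", "left": "30", "right": "10", "value": 20},
  {"id": "30", "left": null, "right": null, "value": 30},
  {"id": "10", "left": "35", "right": "25", "value": 10},
  {"id": "35", "left": null, "right": null, "value": 35},
  {"id": "25", "left": null, "right": null, "value": 25},
  {"id": "5", "left": null, "right": null, "value": 5},
  {"id": "2", "left": "3", "right": null, "value": 2},
  {"id": "3", "left": null, "right": null, "value": 3}
 ],
 "root": "1"}
"""

tree = build_tree_from_json(json.loads(FAILING_CASE))
print(tree_sum(tree))              # 140
print(tree_sum(tree.right.right))  # 70: the subtree rooted at node "10"
```

Cutting the edge between nodes 20 and 10 leaves 70 on each side, which is where the expected answer of 70 comes from.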

I am practicing on a LeetCode-like platform. Please tell me where I am going wrong; if any more information is needed, I will be glad to provide it.

I also removed a condition that previously existed:

if fullsum % 2 != 0:
    return 0

because the function is recursive: eventually the recursion reaches a subtree with an odd sum, and the check would then return 0 (False).


Solution

  • There are several issues:

    This comparison:

        if leftsum+tree.value == rightsum+balancesum:
            return fullsum/2
    

    ...effectively cuts two edges: it compares tree plus its left subtree against everything else, which (below the root) means detaching tree from both its parent and its right child. That is not a valid single-edge split. The next if block has the same problem, mirrored. Both of these if blocks should be removed from your code.

    The following two if blocks are correct, because there tree stays attached to its parent and is detached from only one of its children. (You don't need these blocks either, as explained below.)

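    The recursive calls pass a wrong value for the second parameter. fullsum-rightsum equals tree.value + leftsum: it counts the left subtree itself and leaves out both the right subtree and balancesum, whereas the second argument should be the sum of everything outside the subtree you recurse into. So this: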

        lefty = splitBinaryTree(tree.left, fullsum-rightsum)
    

    ...should really be this:

        lefty = splitBinaryTree(tree.left, balancesum + tree.value + rightsum)
    

    The same fix applies to righty.

    Note that the corrected expression is the same one you used in the last two if blocks discussed earlier, which is why those blocks can be omitted as well: those cases are caught one level deeper in the recursion, by the fullsum == balancesum check.

    The final statements can be simplified to just return lefty or righty.
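One way to see what the second argument should be: at every node, balancesum plus that node's subtree sum must equal the total. The sketch below walks the failing test tree, passing balancesum + tree.value + (sibling subtree sum) exactly as the corrected calls do, and records what each node receives. outside_sums is a hypothetical helper that keys its results by node value, which only works because the values in this test tree are unique.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def subtree_sum(node):
    if node is None:
        return 0
    return node.value + subtree_sum(node.left) + subtree_sum(node.right)


def outside_sums(node, outside=0, acc=None):
    """Map each node's value to the sum of everything outside its subtree.

    Assumes node values are unique (true for the test tree below).
    """
    if acc is None:
        acc = {}
    if node is None:
        return acc
    acc[node.value] = outside
    left, right = subtree_sum(node.left), subtree_sum(node.right)
    # Same second arguments as the corrected recursive calls.
    outside_sums(node.left, outside + node.value + right, acc)
    outside_sums(node.right, outside + node.value + left, acc)
    return acc


tree = BinaryTree(1,
                  BinaryTree(9, BinaryTree(5), BinaryTree(2, BinaryTree(3))),
                  BinaryTree(20, BinaryTree(30),
                             BinaryTree(10, BinaryTree(35), BinaryTree(25))))
outs = outside_sums(tree)
print(outs[9], outs[20])                        # 121 20
print(outs[10], subtree_sum(tree.right.right))  # 70 70, so the split is at node 10
```

With the question's original arguments, node 9 receives 140 - 120 = 20 and node 20 receives 140 - 19 = 121 (the correct values swapped), and node 10 ends up receiving 120 - 30 = 90 instead of 70, so the equal-sum check never succeeds there.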

    With those fixes your code will look like this:

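    def splitBinaryTree(tree, balancesum=0):
        # balancesum is the sum of every node outside the subtree rooted at tree
        if tree is None:
            return 0
        fullsum = icalculatesum(tree)  # sum of this subtree (helper from the question)
        # Cutting the edge above tree leaves fullsum on one side, balancesum on the other
        if fullsum == balancesum:
            return fullsum
        leftsum = icalculatesum(tree.left)
        rightsum = icalculatesum(tree.right)
        # Outside a child's subtree: everything outside tree, plus tree itself and the sibling subtree
        lefty = splitBinaryTree(tree.left, balancesum + tree.value + rightsum)
        righty = splitBinaryTree(tree.right, balancesum + tree.value + leftsum)
        return lefty or righty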
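
A self-contained way to check the corrected recursive version against both test trees from the question. It repeats the question's BinaryTree class and icalculatesum helper (using `is None` instead of `== None`) so it runs on its own; the trees are built inline rather than through the platform's unittest harness.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def icalculatesum(node, sumsofar=0):
    """Sum of the subtree rooted at node, accumulated in pre-order."""
    if node is None:
        return sumsofar
    sumsofar += node.value
    sumsofar = icalculatesum(node.left, sumsofar)
    sumsofar = icalculatesum(node.right, sumsofar)
    return sumsofar


def splitBinaryTree(tree, balancesum=0):
    # balancesum: sum of every node outside the subtree rooted at tree
    if tree is None:
        return 0
    fullsum = icalculatesum(tree)
    if fullsum == balancesum:
        return fullsum
    leftsum = icalculatesum(tree.left)
    rightsum = icalculatesum(tree.right)
    lefty = splitBinaryTree(tree.left, balancesum + tree.value + rightsum)
    righty = splitBinaryTree(tree.right, balancesum + tree.value + leftsum)
    return lefty or righty


# Test case 1 from the question
tree1 = BinaryTree(2,
                   BinaryTree(4, BinaryTree(4), BinaryTree(6)),
                   BinaryTree(10, BinaryTree(3), BinaryTree(3)))
# The failing JSON test case
tree2 = BinaryTree(1,
                   BinaryTree(9, BinaryTree(5), BinaryTree(2, BinaryTree(3))),
                   BinaryTree(20, BinaryTree(30),
                              BinaryTree(10, BinaryTree(35), BinaryTree(25))))

print(splitBinaryTree(tree1))  # 16
print(splitBinaryTree(tree2))  # 70
```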
    

    A more efficient algorithm

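    The approach you have taken is not efficient: it recomputes subtree sums with icalculatesum at every level of the recursion, which takes O(n²) time in the worst case (a degenerate, list-like tree). All you really have to do is find a subtree whose sum is half the total sum. The edge to cut is the one that has the root of this subtree as its child. This covers every possibility, because every edge (parent, child) has the root of a subtree as its child.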

    Secondly, if you calculate the sums incrementally from the bottom up, you visit each node only once, giving a time complexity of O(n). Once you have collected these sums bottom-up, solving the challenge is trivial.
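To see how the repeated subtree sums of the recursive approach add up, the sketch below counts how many real nodes icalculatesum touches when the corrected recursive splitBinaryTree runs on a degenerate tree: n nodes of value 1, each the left child of the next (make_chain, count_visits and the node_visits counter are hypothetical instrumentation). With n odd, no split exists, so the search never stops early. A node whose subtree has k nodes costs 2k - 1 visits (k for its own sum, k - 1 for its left child's), for n² visits in total.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


node_visits = 0  # how many times icalculatesum touches a real (non-None) node


def icalculatesum(node, sumsofar=0):
    global node_visits
    if node is None:
        return sumsofar
    node_visits += 1
    sumsofar += node.value
    sumsofar = icalculatesum(node.left, sumsofar)
    sumsofar = icalculatesum(node.right, sumsofar)
    return sumsofar


def splitBinaryTree(tree, balancesum=0):
    # Same logic as the corrected recursive version.
    if tree is None:
        return 0
    fullsum = icalculatesum(tree)
    if fullsum == balancesum:
        return fullsum
    leftsum = icalculatesum(tree.left)
    rightsum = icalculatesum(tree.right)
    lefty = splitBinaryTree(tree.left, balancesum + tree.value + rightsum)
    righty = splitBinaryTree(tree.right, balancesum + tree.value + leftsum)
    return lefty or righty


def make_chain(n):
    """Degenerate tree: n nodes of value 1, each the left child of the next."""
    root = None
    for _ in range(n):
        root = BinaryTree(1, left=root)
    return root


def count_visits(n):
    global node_visits
    node_visits = 0
    result = splitBinaryTree(make_chain(n))
    return result, node_visits


for n in (11, 101):
    print(n, count_visits(n))  # (0, n * n): 121 and 10201 node visits
```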

    Code

    # Recursively collect all tree sums into a list, with the root's sum listed first
    def getAllTreeSums(tree):
        if not tree:
            return [0]
        left = getAllTreeSums(tree.left)
        right = getAllTreeSums(tree.right)
        return [tree.value + left[0] + right[0], *left, *right]
    
    def splitBinaryTree(tree):
        tree_sums = getAllTreeSums(tree)
        if tree_sums[0] % 2 == 0 and tree_sums[0] // 2 in tree_sums:
            return tree_sums[0] // 2
        return 0
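
getAllTreeSums visits each node once, but the expression [tree.value + left[0] + right[0], *left, *right] copies both child lists at every level, so on a degenerate (list-like) tree the copying alone adds up to O(n²). A minimal sketch that keeps the single bottom-up pass but records the sums in a set is shown below; split_binary_tree_linear is a hypothetical name. Like the original, it is recursive, so very deep trees can still hit Python's default recursion limit.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def split_binary_tree_linear(tree):
    """Hypothetical O(n) variant: one post-order pass, subtree sums kept in a set."""
    sums = set()

    def subtree_sum(node):
        if node is None:
            return 0
        total = node.value + subtree_sum(node.left) + subtree_sum(node.right)
        sums.add(total)  # every subtree sum, including the root's
        return total

    total = subtree_sum(tree)
    # A valid cut exists iff some subtree holds exactly half of the total.
    if total % 2 == 0 and total // 2 in sums:
        return total // 2
    return 0
```

For a nonzero half, this gives the same answers as splitBinaryTree above; when the total is 0, both versions return 0.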