
Find all paths in a tree that sum to S


Given a binary tree and a number ‘S’, find all paths in the tree such that the sum of all the node values of each path equals ‘S’. Note that a path can start and end at any node, but it must always go downward, from parent to child (top to bottom).

This is an answer I found. I understand how it works, but what I initially thought of was something different, and I'm having a hard time figuring out how to implement it.

class TreeNode:
  def __init__(self, val, left=None, right=None):
    self.val = val
    self.left = left
    self.right = right


def count_paths(root, S):
  return count_paths_recursive(root, S, [])


def count_paths_recursive(currentNode, S, currentPath):
  if currentNode is None:
    return 0

  # add the current node to the path
  currentPath.append(currentNode.val)
  pathCount, pathSum = 0, 0
  # find the sums of all sub-paths in the current path list
  for i in range(len(currentPath)-1, -1, -1):
    pathSum += currentPath[i]
    # if the sum of any sub-path is equal to 'S' we increment our path count.
    if pathSum == S:
      pathCount += 1

  # traverse the left sub-tree
  pathCount += count_paths_recursive(currentNode.left, S, currentPath)
  # traverse the right sub-tree
  pathCount += count_paths_recursive(currentNode.right, S, currentPath)

  # remove the current node from the path to backtrack
  # we need to remove the current node while we are going up the recursive call stack
  del currentPath[-1]

  return pathCount


def main():
  root = TreeNode(12)
  root.left = TreeNode(7)
  root.right = TreeNode(1)
  root.left.left = TreeNode(4)
  root.right.left = TreeNode(10)
  root.right.right = TreeNode(5)
  print("Tree has paths: " + str(count_paths(root, 11)))


main()

What I wanted was to do a DFS, but instead of keeping a list, I wanted to keep track of just my current node plus a starting node that I keep in memory. For example,

for the tree...

                 1
               /   \
              7     9
             / \   / \
            6   5 2   3

Say I want to find all the paths that add up to 12. They don't have to start at the root or end at a leaf. My plan was to keep track of the node where the path starts and the node I am currently at.

To keep this post short, I'll skip the first path, since none of

1 -> 7
1 -> 7 -> 6
7 -> 6

add up to 12.
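To double-check the arithmetic, a brute-force sketch can list every parent-to-child path in the example tree and keep the ones that sum to 12. `downward_paths` and `paths_from` are hypothetical helper names used only for this check. The sketch also assumes that 2 is the left child and 3 is the right child of 9.

```python
class TreeNode:
  def __init__(self, val, left=None, right=None):
    self.val = val
    self.left = left
    self.right = right


def paths_from(node):
  """Yield every downward path (as a list of values) that starts exactly at `node`."""
  yield [node.val]
  for child in (node.left, node.right):
    if child is not None:
      for rest in paths_from(child):
        yield [node.val] + rest


def downward_paths(node):
  """Yield every downward path that starts at `node` or at any node below it."""
  if node is None:
    return
  yield from paths_from(node)
  yield from downward_paths(node.left)
  yield from downward_paths(node.right)


# the tree from the question
root = TreeNode(1, TreeNode(7, TreeNode(6), TreeNode(5)), TreeNode(9, TreeNode(2), TreeNode(3)))
matches = [p for p in downward_paths(root) if sum(p) == 12]
print(matches)  # [[1, 9, 2], [7, 5], [9, 3]]
```

This only checks the expected answer. It builds every path explicitly, which is exactly what the start/current idea tries to avoid.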

Moving on to the next possible path, let's start at the root node...

We are at 1, and 1 is less than 12, so we move on to the next node, 7. 1 + 7 is 8, and 8 is less than 12, so we continue to the next node, 5, and 8 + 5 is 13.

13 is greater than 12, so I immediately know that I need to shrink the path from the beginning, in this case at the node with value 1. When shrinking, we must subtract the value of the node we are letting go from the running sum. So 13 - 1 is 12, and we have found a path (7 -> 5) that adds up to the target sum. The same process repeats for the rest of the paths.
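On a single root-to-leaf path, the shrinking described above is the classic two-pointer (sliding-window) technique. Below is a minimal sketch that assumes the path's values are already in a list, top to bottom, and that every value is positive. `count_window_paths` is a hypothetical name and is not part of either code sample in this post.

```python
def count_window_paths(path, S):
  """Count contiguous runs of `path` (values top to bottom) that sum to S.

  Assumes every value is positive, so the sum only grows as the window extends.
  """
  count = 0
  start = 0          # index of the node where the current window begins
  running_sum = 0
  for end, value in enumerate(path):
    running_sum += value                 # extend the window down to `end`
    while running_sum > S and start < end:
      running_sum -= path[start]         # let go of the top node
      start += 1
    if running_sum == S:
      count += 1
  return count


print(count_window_paths([1, 7, 5], 12))  # 1: the window [7, 5]
```

On [1, 7, 5] the sum reaches 13 at 5, the loop drops 1, and the remaining window [7, 5] is counted. In the tree version, the start node plays the role of the `start` index. Advancing it therefore means moving to whichever child of the start node leads down to the current node.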

class TreeNode:
  def __init__(self, val, left=None, right=None):
    self.val = val
    self.left = left
    self.right = right


def count_paths(root, S):
  return count_paths_helper(root, S, root, 0, 0)

def count_paths_helper(current_node, S, start, running_sum, num_paths):
  if current_node is None:
    return 0

  running_sum += current_node.val

  # Found a path
  if running_sum == S:
    return 1
  
  # shrink the path starting from the beginning
  if running_sum > S:
    # if the beginning of the path is also equal to the end of the path continue on our path
    # and get rid of the running sum we currently have
    if start == current_node:
      num_paths += count_paths_helper(current_node.left, S, start, running_sum - start.val, num_paths)
      num_paths += count_paths_helper(current_node.right, S, start, running_sum - start.val, num_paths)
    # shrink the path starting from the beginning moving to the left
    if start.left:
      num_paths += count_paths_helper(current_node, S, start.left, running_sum - start.val, num_paths)
    # shrink the path starting from the beginning moving to the right
    if start.right:
      num_paths += count_paths_helper(current_node, S, start.right, running_sum - start.val, num_paths)
  # if we are on a path that has not exceeded the target and not equal to the target sum
  # continue on our path
  else:
    num_paths += count_paths_helper(current_node.left, S, start, running_sum, num_paths)
    num_paths += count_paths_helper(current_node.right, S, start, running_sum, num_paths)

  return num_paths

def main():
  root = TreeNode(12)
  root.left = TreeNode(7)
  root.right = TreeNode(1)
  root.left.left = TreeNode(4)
  root.right.left = TreeNode(10)
  root.right.right = TreeNode(5)
  print("Tree has paths: " + str(count_paths(root, 11)))


main()

This does not work, and I just don't really understand how I can implement this.


Solution

  • Some issues:

    • Your idea for an algorithm only works when the node values are non-negative.
    • num_paths is accumulated incorrectly. You pass it as an argument to the recursive call, that execution adds to it and returns the increased result, and then you add that result to the caller's num_paths. Apart from the base cases (which return 0 or 1), the returned value is at least as large as the num_paths you passed in, so each such assignment at least doubles num_paths. You can simply omit the function parameter and let each execution of the function start from 0 and return its own count.
    • In the if start == current_node: case, you should pass start.left or start.right as the third argument instead of start, because you have already subtracted start.val from the sum.
    • if start == current_node: should have a corresponding else block. Currently the rest of the code still executes when this condition is true, so the recursive call receives two nodes in the wrong order, representing a sort of "negative" path.
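To see why the first point matters, take a single path that contains a negative value. The minimal sketch below uses made-up values (`path` and `S` are hypothetical). The running sum goes up and then comes back down, so shrinking as soon as it exceeds S discards a start node that a valid path still needs.

```python
from itertools import accumulate

# hypothetical root-to-leaf path, values listed top to bottom
path = [5, 10, -3]
S = 12

prefix_sums = list(accumulate(path))
print(prefix_sums)            # [5, 15, 12]

# the shrinking rule lets go of 5 as soon as the sum reaches 15 > 12,
# yet the whole path 5 -> 10 -> -3 sums to exactly 12
print(sum(path) == S)         # True

# once 5 is dropped, every remaining path ending at -3 misses the target
remaining = [sum(path[i:]) for i in range(1, len(path))]
print(remaining)              # [7, -3]
```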

    Here is a correction:

    def count_paths(root, S):
      return count_paths_helper(root, S, root, 0)
    
    def on_path(node, target):
      # True if target is node itself or lies somewhere below it,
      # i.e. node is on the way down from start to target
      if node is None:
        return False
      return node == target or on_path(node.left, target) or on_path(node.right, target)
    
    def count_paths_helper(current_node, S, start, running_sum):
      # running_sum holds the sum of the values from start down to the parent of current_node
      if current_node is None:
        return 0
    
      running_sum += current_node.val
    
      num_paths = 0
      # shrink the path starting from the beginning
      if running_sum > S:
        # if the beginning of the path is also equal to the end of the path continue on our path
        # and get rid of the running sum we currently have
        if start == current_node:
          num_paths += count_paths_helper(current_node.left, S, start.left, running_sum - start.val)
          num_paths += count_paths_helper(current_node.right, S, start.right, running_sum - start.val)
        else:
          # only one child of start leads down to current_node: move start to that child
          next_start = start.left if on_path(start.left, current_node) else start.right
          # re-check current_node with the shorter path; also subtract current_node.val,
          # because the recursive call adds it again
          num_paths += count_paths_helper(current_node, S, next_start, running_sum - start.val - current_node.val)
      # if we are on a path that has not exceeded the target, continue on our path
      else:
        # Found a path: count it, but keep going, since other paths may end further down
        if running_sum == S:
          num_paths += 1
        num_paths += count_paths_helper(current_node.left, S, start, running_sum)
        num_paths += count_paths_helper(current_node.right, S, start, running_sum)
    
      return num_paths
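If the main goal is to avoid the path list, there is a different, brute-force way to work with only a start node and a current node: fix the start first, then walk down from it with a running sum. This is a sketch of that technique, not of the shrinking algorithm above. `count_from` is a hypothetical helper. The approach takes O(n²) time in the worst case. Because it never shrinks anything, it also works with negative and zero values.

```python
class TreeNode:
  def __init__(self, val, left=None, right=None):
    self.val = val
    self.left = left
    self.right = right


def count_paths(root, S):
  """Count downward paths summing to S by trying every node as the start."""
  if root is None:
    return 0
  return (count_from(root, S, 0)
          + count_paths(root.left, S)
          + count_paths(root.right, S))


def count_from(current_node, S, running_sum):
  """Count paths that begin at a fixed start node and end at current_node or below it."""
  if current_node is None:
    return 0
  running_sum += current_node.val
  found = 1 if running_sum == S else 0
  # keep going even after a match: zero or negative values below may bring the sum back to S
  return (found
          + count_from(current_node.left, S, running_sum)
          + count_from(current_node.right, S, running_sum))


root = TreeNode(1, TreeNode(7, TreeNode(6), TreeNode(5)), TreeNode(9, TreeNode(2), TreeNode(3)))
print(count_paths(root, 12))  # 3
```

The outer recursion chooses the start node, and the inner recursion moves the current node. No list is kept, and the path between the two is never shrunk, only extended.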