Tags: python, algorithm, binary-tree, depth-first-search

Confused about the largest-diameter-of-a-binary-tree problem


Not sure how this works...

class TreeNode:
  def __init__(self, val, left=None, right=None):
    self.val = val
    self.left, self.right = left, right

  def find_diameter(self, root):
    self.treeDiameter = 0  # running maximum, updated by calculate_height
    self.calculate_height(root)
    return self.treeDiameter

  def calculate_height(self, currentNode):
    if currentNode is None:
      return 0

    leftTreeDiameter = self.calculate_height(currentNode.left)
    rightTreeDiameter = self.calculate_height(currentNode.right)

    diameter = leftTreeDiameter + rightTreeDiameter + 1
    self.treeDiameter = max(self.treeDiameter, diameter)

    return max(leftTreeDiameter, rightTreeDiameter) + 1

The above code works to get the max diameter of a binary tree, but I don't understand the last line in calculate_height. Why do we need to return max(leftTreeDiameter, rightTreeDiameter) + 1?

I obviously don't understand it, but what I do know is that for each currentNode we keep going down the left side of the tree, and then do the same for the right. If we end up with no node (meaning we were at a leaf node just before), we return 0, as we don't want to add 1 for a node that does not exist.

The only place that seems to add anything besides 0 is the last line of calculate_height. Although we add leftTreeDiameter + rightTreeDiameter + 1 to get the total diameter, this is only possible because of return 0 and return max(leftTreeDiameter, rightTreeDiameter) + 1, correct?

Also, I am confused about why leftTreeDiameter can simply be assigned self.calculate_height(currentNode.left). I thought I would need something like this:

def calculate_left_height(self, currentNode, height=0):
  if currentNode is None:
    return 0

  self.calculate_left_height(currentNode.left, height + 1)
  
  return height

where we just add 1 to the height each time. So instead of doing something like leftTreeDiameter += self.calculate_height(currentNode.left), I just pass height + 1 as an argument each time we see a node.

But if I do this, I would need a separate method to calculate the right height as well, and my find_diameter method would need to recursively call find_diameter with both root.left and root.right.

Where is my logic wrong, and how does calculate_height actually work? I guess I am having trouble figuring out how to keep track of the stack.


Solution

  • The names used in this code are confusing: leftTreeDiameter and rightTreeDiameter are not diameters, but heights.

    Secondly, the function calculate_height has side effects, which is not very nice: it returns a height and, at the same time, assigns a diameter. This is confusing. Many Python coders prefer a function to be pure: it just returns something, without altering anything else. Alternatively, a function could only alter some state and not return anything. Doing both can be confusing.

    Also, it is confusing that although the class is called TreeNode, its find_diameter method still requires a node as an argument. This is counterintuitive: we would expect the method to act on self as the node, not on the argument.

    But let's just rename the variables and add some comments:

    leftHeight = self.calculate_height(currentNode.left)
    rightHeight = self.calculate_height(currentNode.right)
    
    # What is the size (in nodes) of the longest leaf-to-leaf path
    #   whose top node is the current node?
    diameter = leftHeight + rightHeight + 1
    # Is this path longer than the longest path that we
    #   have found so far? If so, take this one.
    self.treeDiameter = max(self.treeDiameter, diameter)
    # The height of the tree rooted at the current node
    #   is the height of the taller child tree (either left or right),
    #   with one added to account for the current node.
    return max(leftHeight, rightHeight) + 1
    

    It may be obvious, but do realise that self in this process is always the instance on which the find_diameter method is called; it does not really play a role as an actual node, since the root is passed as an argument. So every assignment to self.treeDiameter targets the same attribute. This attribute is not created on every node, only on the node on which you invoke find_diameter.

    I hope the inserted comments have clarified how this algorithm works.
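To make the call stack visible, here is a sketch that logs each call as it returns. It is a stand-alone function rather than the original method, and the trace list, the depth parameter and the sample tree are illustrative additions. The diameter bookkeeping is left out so that only the flow of heights is shown.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left, self.right = left, right


def calculate_height(node, trace, depth=0):
    """Return the height (in nodes) of node's subtree, logging each return."""
    if node is None:
        return 0  # an empty subtree contributes no nodes
    left_height = calculate_height(node.left, trace, depth + 1)
    right_height = calculate_height(node.right, trace, depth + 1)
    height = max(left_height, right_height) + 1
    # Logged only after both children have returned (post-order),
    # indented by how deep this call sits on the stack.
    trace.append(
        f"{'  ' * depth}node {node.val}: "
        f"left={left_height}, right={right_height}, returns {height}"
    )
    return height


# Hypothetical sample tree:
#       1
#      / \
#     2   3
#    / \
#   4   5
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
trace = []
calculate_height(root, trace)
print("\n".join(trace))
# Prints:
#     node 4: left=0, right=0, returns 1
#     node 5: left=0, right=0, returns 1
#   node 2: left=1, right=1, returns 2
#   node 3: left=0, right=0, returns 1
# node 1: left=2, right=1, returns 3
```

The leaves report first because a call can only return after both of its recursive calls have returned. The only place anything other than 0 enters is the + 1 that each non-empty call adds for its own node; max(...) hands up just the taller side, because a path going down from a node can continue into only one child. At node 1 the two heights are 2 and 1, so the longest path through it, 4-2-1-3, has 2 + 1 + 1 = 4 nodes.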

    NB: your own idea of creating calculate_left_height is not going to work: it never alters the value of height that it receives as an argument, discards whatever the recursive call returns, and ends up returning height unchanged. So it returns the same value it already received (or 0 for a missing node). That is obviously not going to compute much...
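For comparison, the question's top-down idea (passing height + 1 down) can work if each call returns what its recursive calls find instead of discarding it. A sketch, using a hypothetical max_depth function that follows both children, so no separate right-hand method is needed:

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left, self.right = left, right


def max_depth(node, depth=0):
    """Top-down height: depth counts the nodes above node on the current path."""
    if node is None:
        # Stepped past a leaf: the path so far had exactly depth nodes.
        return depth
    # Pass depth + 1 down to BOTH children and keep the larger result.
    return max(max_depth(node.left, depth + 1), max_depth(node.right, depth + 1))


# Hypothetical sample tree: 1 with children 2 and 3; 2 has children 4 and 5.
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(max_depth(root))  # 3
```

This computes the same height as the bottom-up calculate_height, but a single call only yields the height of the tree it starts from. The diameter needs the left and right heights at every node, and the bottom-up version gets exactly those from its two recursive calls, which is why it can simply assign leftTreeDiameter = self.calculate_height(currentNode.left).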