python, algorithm, time-complexity, binary-tree

Time complexity of this solution for counting nodes in a complete binary tree


I'm working on LeetCode problem 222. Count Complete Tree Nodes:

Given the root of a complete binary tree, return the number of the nodes in the tree.

According to Wikipedia, every level, except possibly the last, is completely filled in a complete binary tree, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.

Design an algorithm that runs in less than O(𝑛) time complexity.
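To make the definition concrete, here is a minimal sketch that builds a complete binary tree with n nodes using the usual level-order (array) layout, where node i has children 2i + 1 and 2i + 2, plus a naive O(n) counter to compare against. The `TreeNode` class is a stand-in assumed to match LeetCode's (fields `val`, `left`, `right`); `build_complete_tree`, `count_naive` and `last_level_size` are hypothetical helpers for illustration, not part of the problem.

```python
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for LeetCode's TreeNode (val, left, right)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Build a complete binary tree with n nodes in level order.

    Node i (0-based, level order) has children 2*i + 1 and 2*i + 2, so each
    level fills from left to right before the next level starts.
    """
    if n == 0:
        return None
    nodes: List[TreeNode] = [TreeNode(i) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def count_naive(root: Optional[TreeNode]) -> int:
    """Reference O(n) count: visits every node."""
    if root is None:
        return 0
    return 1 + count_naive(root.left) + count_naive(root.right)


def last_level_size(n: int) -> int:
    """Number of nodes on the last level h of a complete tree with n >= 1 nodes."""
    h = n.bit_length() - 1          # levels are numbered 0..h
    return n - ((1 << h) - 1)       # always between 1 and 2**h inclusive
```

For example, `last_level_size(7)` is 4 (a perfect tree of height 2) and `last_level_size(8)` is 1.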

This is my solution.

def countNodes(self, root: Optional[TreeNode]) -> int:
    if not root: return 0
    def maxDepth(root):
        # Height of the tree (depth of the bottom layer, counting from 0).
        # It follows only left children: in a complete tree the leftmost path
        # always reaches the bottom layer, so this takes O(d) time.
        if not root:
            return -1

        return 1 + maxDepth(root.left)
    

    d = maxDepth(root)
    def dfs(root, depth):
        if not root:
            return 0
        if depth == d:
            # if we've reached the bottom nodes, return 1 (counting itself)
            return 1
        
        right = 0  # stays 0 unless the right subtree has to be searched
        left = dfs(root.left, depth + 1)
        if left == ((2 ** (d - depth)) // 2):
            # Only search the right subtree (otherwise halve the problem) if the
            # left subtree's bottom layer is full, i.e. it returned 2 ** (d - depth - 1)
            # bottom nodes. For instance, if the depth of the tree is 4 (counting from 0)
            # and the current depth is 2, the left child's subtree must return 2.
            # If it returns fewer, the bottom layer ends inside the left subtree,
            # so the right subtree has no bottom nodes and can be skipped.
            right = dfs(root.right, depth + 1)
        
        return left + right

    bottom = dfs(root, 0)
    top = (2 ** (d)) - 1
    return top + bottom

I'm unsure whether the check in the dfs(root, depth) function actually reduces the time complexity. I suspect it is still O(n) in the worst case, when the bottom level is full, because dfs(root.right, depth + 1) would then be called at every node.
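The worst-case concern can be checked empirically. The sketch below keeps the same `dfs` logic but rewrites it as a standalone function (`count_with_calls`, a hypothetical name) that also counts how many times `dfs` is called; the recursive `maxDepth` is replaced by an equivalent loop. `TreeNode` and `build_complete_tree` are minimal stand-ins assumed for illustration, not LeetCode's code.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for LeetCode's TreeNode (only the child pointers)."""

    def __init__(self) -> None:
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Level-order layout: node i has children 2*i + 1 and 2*i + 2."""
    if n == 0:
        return None
    nodes: List[TreeNode] = [TreeNode() for _ in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def count_with_calls(root: Optional[TreeNode]) -> Tuple[int, int]:
    """The question's algorithm; returns (node count, number of dfs calls)."""
    if not root:
        return 0, 0

    # d = depth of the bottom layer, found by following left children
    d, node = 0, root.left
    while node:
        d += 1
        node = node.left

    calls = 0

    def dfs(node: Optional[TreeNode], depth: int) -> int:
        nonlocal calls
        calls += 1
        if not node:
            return 0
        if depth == d:
            return 1  # a bottom-layer node counts itself
        left = dfs(node.left, depth + 1)
        right = 0
        # Recurse right only if the left subtree's bottom layer is full
        if left == (2 ** (d - depth)) // 2:
            right = dfs(node.right, depth + 1)
        return left + right

    bottom = dfs(root, 0)
    return (2 ** d) - 1 + bottom, calls
```

On a perfect tree every left subtree is full, so `dfs` enters every node exactly once and the call count equals n. With a single node on the bottom layer (n = 2^d) it makes only d + 2 calls, which is the O(log n) best case.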

Thanks for the help!


Solution

  • Your solution visits each bottom-layer leaf to add 1 to the count (via return 1). That is the only way the count can increase, and the bottom layer can hold roughly half of all nodes, so it has O(𝑛) nodes in the worst case. Therefore the algorithm cannot have a better worst-case time complexity than O(𝑛).

    Admittedly, the algorithm will avoid a search in the right subtree when it finds that the left subtree is not a perfect tree (its bottom layer is not completely filled), and that makes the best case time complexity O(log𝑛), but the average and worst time complexities remain O(𝑛).

    To get a sublinear time complexity you should implement a binary search. The idea is to number the (potential) bottom-layer nodes from 0 to 2^h − 1, where h is the height of the tree. Read from its most significant bit, the h-bit representation of such a leaf identifier also spells out the path to that (potential) leaf, where 0 means "left" and 1 means "right". The binary search can treat these identifiers as ordinary integers and compute a middle value from the current two extremes; that value again represents a potential leaf (and the path to it). Whether the search window narrows to the left or to the right is decided by checking whether that middle path ends in a real node.

    As each of these paths takes O(log𝑛) time to traverse, and the binary search checks O(log𝑛) of them, the overall time complexity is O(log²𝑛).

    Here is an implementation:

    from typing import Optional  # TreeNode is predefined by LeetCode

    class Solution:
        def countNodes(self, root: Optional[TreeNode]) -> int:
            if not root or not root.left:  # base cases: empty tree or a single node
                return int(bool(root))

            # Height of the tree = number of edges on the leftmost path
            height = 0
            node = root.left
            while node:
                height += 1
                node = node.left

            # Given a bitpath (0=left, 1=right), tell if the bottom layer has a node there
            def isvalidpath(bitpath):
                node = root
                for bit in range(height - 1, -1, -1):
                    node = node.right if (bitpath >> bit) & 1 else node.left
                return node is not None

            # Binary search over the bit representations of the paths
            # to bottom-layer leaves. Path 0 (the leftmost leaf) always exists.
            left = 1              # bitpath to a potential second node in the bottom layer
            right = 1 << height   # maximum number of nodes in the bottom layer
            while left < right:
                mid = (left + right) >> 1
                if isvalidpath(mid):
                    left = mid + 1
                else:
                    right = mid
            # Now left == right and equals the number of nodes in the bottom layer.
            # The levels above it are perfect and hold 2^height - 1 nodes.
            return left + (1 << height) - 1
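The leaf-numbering idea from the answer can be illustrated on its own. This sketch (`bitpath_to_moves` is a hypothetical helper) decodes a bottom-layer index into the sequence of left/right moves, reading the `height` bits from the most significant one, the same order in which `isvalidpath` consumes them. Increasing indices correspond to leaves from left to right, which is what makes a binary search over plain integers valid.

```python
from typing import List


def bitpath_to_moves(bitpath: int, height: int) -> List[str]:
    """Turn a bottom-layer index into root-to-leaf moves, most significant bit first.

    Bit value 0 means "go left" and 1 means "go right"; an index in
    range(2 ** height) needs exactly `height` bits.
    """
    if not 0 <= bitpath < (1 << height):
        raise ValueError("bitpath must be in range(2 ** height)")
    return ["R" if (bitpath >> bit) & 1 else "L"
            for bit in range(height - 1, -1, -1)]


# Potential leaves of a height-3 tree, in left-to-right order
for k in range(1 << 3):
    print(k, format(k, "03b"), "".join(bitpath_to_moves(k, 3)))
# e.g. the line for k = 5 reads: 5 101 RLR
```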
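To see the O(log²𝑛) bound in practice, the binary-search approach can be instrumented to count child-pointer hops. This is a sketch under assumptions: `TreeNode` and `build_complete_tree` are minimal stand-ins, and `count_nodes_with_steps` is a hypothetical standalone version of `countNodes` that returns the hop count alongside the result. Each loop iteration at least halves a window of at most 2^h − 1 candidate indices, so there are at most h probes of h hops each, plus h hops to find the height.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for LeetCode's TreeNode (only the child pointers)."""

    def __init__(self) -> None:
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Level-order layout: node i has children 2*i + 1 and 2*i + 2."""
    if n == 0:
        return None
    nodes: List[TreeNode] = [TreeNode() for _ in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def count_nodes_with_steps(root: Optional[TreeNode]) -> Tuple[int, int]:
    """Binary-search node count; also returns the number of child-pointer hops."""
    steps = 0
    if not root or not root.left:  # empty tree or a single node
        return int(bool(root)), steps

    # Height = number of edges on the leftmost path
    height, node = 0, root.left
    while node:
        height += 1
        steps += 1
        node = node.left

    def isvalidpath(bitpath: int) -> bool:
        # Follow the bits from most to least significant (0 = left, 1 = right)
        nonlocal steps
        node = root
        for bit in range(height - 1, -1, -1):
            node = node.right if (bitpath >> bit) & 1 else node.left
            steps += 1
        return node is not None

    left, right = 1, 1 << height  # path 0 (the leftmost leaf) always exists
    while left < right:
        mid = (left + right) >> 1
        if isvalidpath(mid):
            left = mid + 1
        else:
            right = mid
    # left is now the bottom-layer size; the levels above hold 2**height - 1 nodes
    return left + (1 << height) - 1, steps
```

For n = 2**14 - 1 (height 13) this bound allows at most 13 * 14 = 182 hops, compared with 16,383 node visits for a full traversal.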