
How to find kth smallest element in a BST when the tree may be modified frequently?


I'm working on LeetCode 230: Kth Smallest Element in a BST. My Python code uses a recursive inorder traversal; although it is not directly relevant to this question, it is given below for reference. In the worst case it takes linear time, O(n), because it may visit every node (k = n).

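from typing import Optional

# TreeNode is the node class supplied by LeetCode (attributes: val, left, right).
def kth_smallest(root: Optional[TreeNode], k: int) -> int: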
    def dfs(node: TreeNode, count: int) -> tuple[int, int]:
        """
        :param node: Current node
        :param count: Number of nodes visited so far
        :returns: a 2-tuple consisting of the number of nodes visited, and the value of the last visited node.
            If the kth node is found, immediately returns the value of the kth node
        """
        if node.left is not None:
            # After this call, count also includes every node in the left subtree.
            count, val = dfs(node.left, count)
            if count == k:
                return k, val
        # count the current node.
        count += 1
        if count == k or node.right is None:
            return count, node.val
        return dfs(node.right, count)

    assert root is not None
    return dfs(root, 0)[1]

There is a follow up question that asks:

If the BST is modified often (i.e., we can do insert and delete operations) and you need to find the kth smallest frequently, how would you optimize?

I'm thinking of a B+ tree. Find, insert and delete operations run in O(log n) time, and a range query that returns k elements takes O(k + log n) time.

But can we find the kth smallest element in a B+ tree any faster than O(n)? I mention the B+ tree only because it was my first idea; any other data structure that fits the problem description would work just as well.


Solution

  • You can do this with a B+ tree, B-tree, AVL tree, red-black tree, skip list, or another efficient sorted container of your choice.

    The idea is to extend nodes with a size attribute so you can greedily decide how to navigate down the structure to the kth value.
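
    To show the greedy descent in isolation, here is a minimal sketch on a plain, unbalanced BST whose nodes carry a size field. The names SizedNode, insert and select are made up for this illustration and it assumes distinct values. Without rebalancing, each operation costs O(height) rather than O(log n).

```python
from typing import Optional


class SizedNode:
    """Hypothetical BST node that also stores the size of its subtree."""

    def __init__(self, val: int):
        self.val = val
        self.left: Optional["SizedNode"] = None
        self.right: Optional["SizedNode"] = None
        self.size = 1  # Number of nodes in the subtree rooted here


def insert(node: Optional[SizedNode], val: int) -> SizedNode:
    """Plain (unbalanced) BST insert that keeps size up to date."""
    if node is None:
        return SizedNode(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.size += 1
    return node


def select(node: Optional[SizedNode], k: int) -> int:
    """Return the kth smallest value (1-based) by comparing k with left subtree sizes."""
    if node is None or not 1 <= k <= node.size:
        raise IndexError(f"k={k} is out of range")
    while True:
        left_size = node.left.size if node.left else 0
        if k <= left_size:        # Answer lies in the left subtree
            node = node.left
        elif k == left_size + 1:  # Current node is the kth smallest
            return node.val
        else:                     # Skip the left subtree and this node
            k -= left_size + 1
            node = node.right


root = None
for v in [50, 30, 70, 20, 40, 60, 80]:
    root = insert(root, v)
print(select(root, 3))  # 40
```

    Each step discards either the left subtree plus the current node or everything else, so the number of steps is bounded by the height of the tree. Keeping the tree balanced is what turns that bound into O(log n).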

    Since you started out with a BST, you can turn it into an AVL tree (augmenting nodes with a balance factor attribute) and also into an order statistic tree (augmenting nodes with a size attribute). Insert, delete and kth-smallest lookups then each take O(log n) time.

    Contrary to how LeetCode presents their BST, it is good practice to store the root in a separate class so that you can define insert/delete methods on it and support the concept of an empty tree.

    That second class could look like this:

    class AugmentedAvl:
        def __init__(self):
            self.root = None
    
        def insert(self, value):
            self.root = AugmentedAvlNode.insert(self.root, value)
    
        def delete(self, value):
            self.root = AugmentedAvlNode.delete(self.root, value)
    
        def kth_smallest(self, k):
            # Returns None when k is out of range (k < 1 or k > size)
            if self.root and 1 <= k <= self.root.size:
                return self.root.kth_smallest(k)
        
        def size(self):
            return AugmentedAvlNode.sizeof(self.root)
    

    The augmented Node class would have most of the logic:

    class AugmentedAvlNode:
        def __init__(self, val):
            self.val = val
            self.children = [None, None]  # Can access through left,right getters/setters
            self.bf = 0  # Balance factor should be -1, 0 or 1.
            self.size = 1  # Number of nodes in the tree rooted by this node
            
        @property
        def left(self):
            return self.children[0]
        
        @left.setter
        def left(self, value):
            self.children[0] = value    
    
        @property
        def right(self):
            return self.children[1]
        
        @right.setter
        def right(self, value):
            self.children[1] = value    
    
        @staticmethod
        def sizeof(node):
            return node.size if node else 0
    
        def min(self):
            return self.left.min() if self.left else self.val
    
        def attach(self, side, child):  # Does not take care of balance factors
            self.children[side] = child
            self.size = sum(map(self.sizeof, self.children)) + 1
            return self
    
        def adjust_balance(self, add_bf):
            if self.bf != add_bf: # The balance factor update will stay in range
                self.bf += add_bf
                return self
            # Imbalance would occur. Prepare for rotation:
            side = (1 - self.bf) // 2
            lifting = self.children[1-side]
            if lifting.bf == -self.bf:  # Inner grandchild is heavy: start double rotation
                child, lifting = lifting, lifting.children[side]
                self.bf, child.bf = (+(lifting.bf < 0), -(lifting.bf > 0))[::-self.bf]
                lifting.bf = 0
                lifting.children[1-side] = child.attach(side, lifting.children[1-side])
            else: # Prepare simple rotation:
                self.bf *= not lifting.bf
                lifting.bf = -self.bf
            # Finish rotation
            return lifting.attach(side, self.attach(1-side, lifting.children[side]))
    
        @classmethod
        def insert(cls, node, value):
            if not node:  # Found the spot where to insert
                return cls(value)
            if node.val == value:
                raise ValueError(f"insert({value}): tree already has this value")
            side = value >= node.val
            child = node.children[side]
            orig_bf = child and child.bf
            node.children[side] = cls.insert(child, value)
            node.size += 1
            if child and (orig_bf or child.bf == 0): # Height didn't change
                return node
            return node.adjust_balance((-1, 1)[side])
    
        @classmethod
        def delete(cls, node, value):
            if not node:
                raise ValueError(f"delete({value}): value not found in tree")
            if node.val == value:  # Found the node to delete
                if not node.left or not node.right:  # Simple case
                    return node.left or node.right
                value = node.val = node.right.min()  # Delete successor node instead
            side = value >= node.val
            child = node.children[side]
            orig_bf = child and child.bf        
            child = node.children[side] = cls.delete(child, value)
            node.size -= 1
            if child and (child.bf or orig_bf == 0): # Height didn't change
                return node
            return node.adjust_balance((1, -1)[side])
                  
        def kth_smallest(self, k):
            left_size = self.sizeof(self.left)
            if k <= left_size:
                return self.left.kth_smallest(k)
            k -= left_size + 1
            if k == 0:
                return self.val
            return self.right.kth_smallest(k)        
    

    Here is code that tests the above implementation:

    class TestableAvl(AugmentedAvl):
        def inorder(self, node=None):
            node = node or self.root
            if node:
                if node.left:
                    yield from self.inorder(node.left)
                yield node.val
                if node.right:
                    yield from self.inorder(node.right)
        
        def verify(self):
            # inorder sequence must be sorted
            values = list(self.inorder())
            if values != sorted(values):
                raise ValueError(f"Tree does not have a correct inorder sequence {values}")
            self.verify_node(self.root)
    
        def verify_node(self, node):
            # balance factors and sizes must be correct
            if not node:
                return -1, 0  # height, size
            left_height, left_size = self.verify_node(node.left)
            right_height, right_size = self.verify_node(node.right) 
            height = max(left_height, right_height) + 1
            bf = right_height - left_height
            size = 1 + left_size + right_size
            if node.size != size:
                raise ValueError(f"Size inconsistent at node {node.val}")
            if node.bf != bf or abs(bf) > 1:
                raise ValueError(f"Balance factor inconsistent at node {node.val}")
            return height, size
    
    
    from random import shuffle
    
    def main():
        for i in range(30):
            print(f"Running random test {i}")
            lst = list(range(130))
            shuffle(lst)
            tree = TestableAvl()
            for value in lst:
                tree.insert(value)
                tree.verify()
            if tree.size() != len(lst):
                raise ValueError("Tree does not have the expected size")
            for k in lst:
                if tree.kth_smallest(k + 1) != k:
                    raise ValueError(f"kth_smallest({k + 1}) returns wrong value")
            shuffle(lst)
            for value in lst:
                tree.delete(value)
                tree.verify()
        print("done")
    
    main()
    

    Other solutions

    There are of course libraries that offer this functionality. One is the third-party Sorted Containers package: its SortedList implements __getitem__, so you can get the kth smallest element with lst[k-1].
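
    A short usage sketch, assuming the sortedcontainers package is installed (pip install sortedcontainers). The add and remove methods and indexing belong to SortedList; the kth_smallest helper is a hypothetical wrapper added for this illustration, and the fallback class is only a slow stand-in so the snippet still runs where the package is missing.

```python
import bisect

try:
    from sortedcontainers import SortedList  # Third-party package
except ImportError:
    class SortedList:
        """Stand-in (an assumption for this sketch) with O(n) add/remove."""

        def __init__(self, iterable=()):
            self._items = sorted(iterable)

        def add(self, value):
            bisect.insort(self._items, value)

        def remove(self, value):
            self._items.remove(value)  # Raises ValueError if value is absent

        def __getitem__(self, index):
            return self._items[index]

        def __len__(self):
            return len(self._items)


def kth_smallest(values, k):
    """Hypothetical helper: kth smallest (1-based) of a sorted container."""
    if not 1 <= k <= len(values):
        raise IndexError(f"k={k} is out of range")
    return values[k - 1]


values = SortedList([50, 30, 70, 20, 40])
print(kth_smallest(values, 2))  # 30
values.add(10)                  # Insert keeps the container sorted
print(kth_smallest(values, 2))  # 20
values.remove(30)               # Delete by value
print(kth_smallest(values, 3))  # 40
```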

    See also: Is there a module for balanced binary tree in Python's standard library?