Tags: java, binary-search-tree, ranking

Nth largest node in Binary Search Tree


Given a number N, I'm trying to find the Nth largest node in a binary search tree. The solutions I've found online all find the Nth smallest node instead, such as this one:

/**
 * Return the key in the symbol table whose rank is {@code k}.
 * This is the (k+1)st smallest key in the symbol table.
 *
 * @param  k the order statistic
 * @return the key in the symbol table of rank {@code k}
 * @throws IllegalArgumentException unless {@code k} is between 0 and
 *        <em>n</em>–1
 */
public Key select(int k) {
    if (k < 0 || k >= size()) {
        throw new IllegalArgumentException("argument to select() is invalid: " + k);
    }
    Node x = select(root, k);
    return x.key;
}

// Return key of rank k. 
private Node select(Node x, int k) {
    if (x == null) return null; 
    int t = size(x.left); 
    if      (t > k) return select(x.left,  k); 
    else if (t < k) return select(x.right, k-t-1); 
    else            return x; 
} 

Source: https://algs4.cs.princeton.edu/32bst/BST.java.html

How would I convert the select(Node x, int k) method to find the Nth largest node?

For example, in a BST that looks like:

       30
     /    \
    20    35
   / \    / \
 15   25 31 40

The largest node has a key of 40.

The ranked BST, with each node labeled by its rank counted from the largest (starting at 1), would look like:

        4
     /    \
    6      2
   / \    / \
  7   5  3   1

Solution

  • Note that ranks in this BST start from 0: select(0) returns the smallest key.

    A simpler way

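    For a BST containing X elements, with ranks numbered 0 to (X-1) from the smallest:

    The element of rank N is the (X-N)th largest element (counting the largest as the 1st), and vice versa. So the Nth largest key is select(X - N), where X is size(). If you also count from 0 on the largest side, use select(X - 1 - N) instead.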
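The rank conversion above can be sketched as follows. This is a minimal illustration under stated assumptions, not the algs4 class itself: SizedBst and nthLargest are hypothetical names, keys are plain ints, and n is 1-based (n = 1 is the largest key). The select method follows the same logic as the one in the question.

```java
// Minimal sketch: SizedBst and nthLargest are hypothetical names, not part of algs4.
class SizedBst {
    private static final class Node {
        final int key;
        Node left, right;
        int size = 1; // number of nodes in the subtree rooted at this node
        Node(int key) { this.key = key; }
    }

    private Node root;

    private static int size(Node x) { return x == null ? 0 : x.size; }

    int size() { return size(root); }

    void put(int key) { root = put(root, key); }

    private static Node put(Node x, int key) {
        if (x == null) return new Node(key);
        if      (key < x.key) x.left  = put(x.left, key);
        else if (key > x.key) x.right = put(x.right, key);
        // Duplicate keys are ignored; subtree size is recomputed on the way back up.
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    // Same logic as the question's select: k is a 0-based rank from the smallest.
    int select(int k) {
        if (k < 0 || k >= size()) {
            throw new IllegalArgumentException("argument to select() is invalid: " + k);
        }
        return select(root, k).key;
    }

    private static Node select(Node x, int k) {
        if (x == null) return null;
        int t = size(x.left);
        if      (t > k) return select(x.left, k);
        else if (t < k) return select(x.right, k - t - 1);
        else            return x;
    }

    // Assumption: n is 1-based, so the Nth largest has 0-based rank size() - n.
    int nthLargest(int n) { return select(size() - n); }

    public static void main(String[] args) {
        SizedBst bst = new SizedBst();
        for (int key : new int[] {30, 20, 35, 15, 25, 31, 40}) bst.put(key);
        System.out.println(bst.nthLargest(1)); // largest key in the example tree
        System.out.println(bst.nthLargest(4)); // 4th largest key
    }
}
```

With the example tree from the question, nthLargest(1) gives 40 and nthLargest(4) gives 30, matching the ranked tree. Out-of-range values of n are rejected by the existing bounds check in select.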

    If you have no choice but to change the method

    What select does here is essentially a binary search on the rank. If we mirror it so that it goes right for smaller ranks and left for larger ranks, it returns the answer we want.

    Swap x.left and x.right (k is then a 0-based rank from the largest, so k = 0 returns the largest node):

    private Node select(Node x, int k) {
        if (x == null) return null; 
        int t = size(x.right); 
        if      (t > k) return select(x.right,  k); 
        else if (t < k) return select(x.left, k-t-1); 
        else            return x; 
    }
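To check the mirrored method against the example tree, here is a small self-contained sketch. It assumes int keys and a size field maintained on insert; MirroredBst and selectLargest are hypothetical names used only for illustration. Here k is 0-based from the largest, so the 1-based ranks in the question's ranked tree correspond to k + 1.

```java
// Minimal sketch: MirroredBst and selectLargest are hypothetical names, not part of algs4.
class MirroredBst {
    private static final class Node {
        final int key;
        Node left, right;
        int size = 1; // number of nodes in the subtree rooted at this node
        Node(int key) { this.key = key; }
    }

    private Node root;

    private static int size(Node x) { return x == null ? 0 : x.size; }

    int size() { return size(root); }

    void put(int key) { root = put(root, key); }

    private static Node put(Node x, int key) {
        if (x == null) return new Node(key);
        if      (key < x.key) x.left  = put(x.left, key);
        else if (key > x.key) x.right = put(x.right, key);
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    // k is a 0-based rank from the largest: k = 0 is the largest key.
    int selectLargest(int k) {
        if (k < 0 || k >= size()) {
            throw new IllegalArgumentException("argument to selectLargest() is invalid: " + k);
        }
        return selectLargest(root, k).key;
    }

    private static Node selectLargest(Node x, int k) {
        if (x == null) return null;
        int t = size(x.right);                     // number of keys larger than x.key in this subtree
        if      (t > k) return selectLargest(x.right, k);        // answer is among the larger keys
        else if (t < k) return selectLargest(x.left, k - t - 1); // skip the right subtree and x itself
        else            return x;                  // exactly k keys are larger than x
    }

    public static void main(String[] args) {
        MirroredBst bst = new MirroredBst();
        for (int key : new int[] {30, 20, 35, 15, 25, 31, 40}) bst.put(key);
        for (int k = 0; k < bst.size(); k++) {
            System.out.println((k + 1) + ": " + bst.selectLargest(k));
        }
    }
}
```

For the tree in the question, selectLargest(0) is 40, selectLargest(3) is 30 and selectLargest(6) is 15, which agrees with the ranked tree once ranks are shifted by one.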