
Finding index of kth smallest element in an array efficiently (iterative)?


I would like to find the kth smallest element in an array, but I actually need its index for my partition method.

I found this code for finding the kth smallest element on this blog: http://blog.teamleadnet.com/2012/07/quick-select-algorithm-find-kth-element.html

But this only returns the value, not the index.

Do you have any idea how I can find the index of it efficiently?


Solution

  • The simplest way is to create an additional indices array of the same length and fill it with the numbers 0 to length-1. Whenever the arr array is changed, apply the same change to the indices array. At the end, return the corresponding entry from the indices array instead of the value. You don't even have to understand the original algorithm to do this. Here's the modified method (my changes are marked with ***):

    public static int selectKthIndex(int[] arr, int k) {
        if (arr == null || k < 0 || arr.length <= k)
            throw new IllegalArgumentException();
    
        int from = 0, to = arr.length - 1;
    
        // ***ADDED: create and fill indices array
        int[] indices = new int[arr.length];
        for (int i = 0; i < indices.length; i++)
            indices[i] = i;
    
        // if from == to we reached the kth element
        while (from < to) {
            int r = from, w = to;
            int mid = arr[(r + w) / 2];
    
            // stop when the reader and writer meet
            while (r < w) {
    
                if (arr[r] >= mid) { // put the large values at the end
                    int tmp = arr[w];
                    arr[w] = arr[r];
                    arr[r] = tmp;
                    // *** ADDED: here's the only place where arr is changed
                    // change indices array in the same way
                    tmp = indices[w];
                    indices[w] = indices[r];
                    indices[r] = tmp;
                    w--;
                } else { // the value is smaller than the pivot, skip
                    r++;
                }
            }
    
            // if we stepped up (r++) we need to step one down
            if (arr[r] > mid)
                r--;
    
            // the r pointer is on the end of the first k elements
            if (k <= r) {
                to = r;
            } else {
                from = r + 1;
            }
        }
    
        // *** CHANGED: return indices[k] instead of arr[k]
        return indices[k];
    }
    

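    Note that this method reorders the original arr array in place. If you don't want that, add arr = arr.clone(); at the beginning of the method. The returned index still refers to the original positions, because the indices array tracks every swap made on the copy.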
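If you would rather not touch arr at all (and skip the clone), here is a minimal sketch of a different arrangement; it is not the blog's algorithm. It runs an ordinary iterative quickselect with a Lomuto partition directly on the index array, always comparing values through arr[idx[i]]. The class name IndexSelect and the helpers partition and swap are hypothetical names chosen for this example. Because only idx is permuted, the returned index refers to the original positions by construction.

```java
import java.util.Arrays;

public class IndexSelect {

    /**
     * Returns an index i such that arr[i] is the kth smallest value (k is 0-based).
     * arr is never modified; only a private array of indices is reordered.
     */
    public static int selectKthIndex(int[] arr, int k) {
        if (arr == null || k < 0 || k >= arr.length)
            throw new IllegalArgumentException("k out of range");

        // idx[i] starts as i; all swaps happen on idx, never on arr
        int[] idx = new int[arr.length];
        for (int i = 0; i < idx.length; i++)
            idx[i] = i;

        int lo = 0, hi = arr.length - 1;
        while (lo < hi) {
            // partition idx[lo..hi] around the value at the middle position
            int p = partition(arr, idx, lo, hi, lo + (hi - lo) / 2);
            if (k == p) return idx[k];   // pivot landed exactly at rank k
            if (k < p) hi = p - 1;       // kth smallest is left of the pivot
            else lo = p + 1;             // kth smallest is right of the pivot
        }
        return idx[k];
    }

    // Lomuto partition: moves the pivot index to hi, places every index whose
    // value is smaller than the pivot value in front, and returns the pivot's
    // final position (values left of it are smaller, values right are >=).
    private static int partition(int[] arr, int[] idx, int lo, int hi, int pivotPos) {
        int pivot = arr[idx[pivotPos]];
        swap(idx, pivotPos, hi);
        int store = lo;
        for (int i = lo; i < hi; i++) {
            if (arr[idx[i]] < pivot) {
                swap(idx, i, store);
                store++;
            }
        }
        swap(idx, store, hi); // pivot index lands at its sorted position
        return store;
    }

    private static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        int[] arr = {7, 2, 9, 4, 4, 1, 8};
        int[] before = arr.clone();
        int[] sorted = arr.clone();
        Arrays.sort(sorted);
        for (int k = 0; k < arr.length; k++) {
            int i = selectKthIndex(arr, k);
            // arr[i] should equal sorted[k]
            System.out.println("k=" + k + " -> index " + i + " (value " + arr[i]
                    + ", expected " + sorted[k] + ")");
        }
        System.out.println("arr unchanged: " + Arrays.equals(arr, before));
    }
}
```

With duplicate values several indices hold the kth value, and which one comes back depends on the partition order. Lomuto partitioning also degrades toward quadratic time when many values are equal; a three-way partition is the usual remedy if that matters for your data.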