java algorithm quicksort quickselect

Quickselect implementation not working


I am trying to write code to find the nth smallest item in an array, and I am struggling with it. Based on the algorithm from my old college textbook, this looks correct. However, I am obviously doing something wrong, because it throws a StackOverflowError.

My approach is:

  1. Set the pivot index to start + (end - start) / 2 (rather than (start + end) / 2, to prevent integer overflow)
  2. Use the integer at this location as the pivot value that I compare everything to
  3. Iterate and swap elements around this pivot so the array is partitioned (ordered relative to the pivot)
  4. If n == pivot, then I think I am done
  5. Otherwise, if I want the 4th smallest element and the pivot is 3, for example, then I need to look on the right side (or on the left side if I wanted the 2nd smallest element).


public static void main(String[] args) {
    int[] elements = {30, 50, 20, 10};
    quickSelect(elements, 3);
}

private static int quickSelect(int[] elements2, int k) {
    return quickSelect(elements2, k, 0, elements2.length - 1);
}

private static int quickSelect(int[] elements, int k, int start, int end) {
    int pivot = start + (end - start) / 2;
    int midpoint = elements[pivot];
    int i = start, j = end;

    while (i < j) {
        while (elements[i] < midpoint) {
            i++;
        }

        while (elements[j] > midpoint) {
            j--;
        }

        if (i <= j) {
            int temp = elements[i];
            elements[i] = elements[j];
            elements[j] = temp;
            i++;
            j--;
        }
    }
    // Guessing something's wrong here
    if (k == pivot) {
        System.out.println(elements[pivot]);
        return pivot;
    } else if (k < pivot) {
        return quickSelect(elements, k, start, pivot - 1);
    } else {
        return quickSelect(elements, k, pivot + 1, end);
    }
}

Edit: Please at least bother commenting why if you're going to downvote a valid question.


Solution

  • Alright, so the first thing I did was rework how I get my pivot/partition point. The shortcoming, as T. Claverie pointed out, is that the index I was calling the pivot doesn't reliably hold the pivot value once partitioning is done, since that element can be swapped to a different position during the partitioning phase.
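
    To make that concrete, here is a minimal sketch that runs the partition loop from the question on the same array and prints what sits at the "pivot" index before and after. The class and method names (PivotDriftDemo, partitionAroundMiddle) are made up for this illustration; the loop body is copied from the question unchanged.

    ```java
    import java.util.Arrays;

    class PivotDriftDemo {

        // The partition loop from the question, pulled into its own
        // (hypothetically named) method so it can be run in isolation.
        static void partitionAroundMiddle(int[] elements, int start, int end) {
            int pivot = start + (end - start) / 2;
            int midpoint = elements[pivot];
            int i = start, j = end;

            while (i < j) {
                while (elements[i] < midpoint) {
                    i++;
                }
                while (elements[j] > midpoint) {
                    j--;
                }
                if (i <= j) {
                    int temp = elements[i];
                    elements[i] = elements[j];
                    elements[j] = temp;
                    i++;
                    j--;
                }
            }
        }

        public static void main(String[] args) {
            int[] elements = {30, 50, 20, 10};
            int pivot = 0 + (3 - 0) / 2;                    // index 1

            System.out.println(elements[pivot]);            // prints 50
            partitionAroundMiddle(elements, 0, 3);
            System.out.println(Arrays.toString(elements));  // prints [30, 10, 20, 50]
            System.out.println(elements[pivot]);            // prints 10
        }
    }
    ```

    The value 50 ends up at index 3, yet the question's code keeps comparing k against index 1, so the comparison no longer says anything about where the pivot value really is.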

    I moved the partitioning code into its own method, shown below, and it uses a slightly different scheme.

    I choose the first element (at start) as the pivot and grow a "section" directly after it that holds the items less than the pivot. Then I swap the pivot with the last item in that section and return that final index as the pivot's position.

    This could be cleaned up further (for example, with a separate swap method).

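    private static int getPivot(int[] elements, int start, int end) {
        int pivot = start;      // the pivot value stays at start until the final swap
        int lessThan = start;   // index of the last item known to be < the pivot

        for (int i = start; i <= end; i++) {
            int currentElement = elements[i];
            if (currentElement < elements[pivot]) {
                // Grow the "less than" section by one and move this item into it
                lessThan++;
                int tmp = elements[lessThan];
                elements[lessThan] = elements[i];
                elements[i] = tmp;
            }
        }
        // Put the pivot at the end of the "less than" section: its final position
        int tmp = elements[lessThan];
        elements[lessThan] = elements[pivot];
        elements[pivot] = tmp;

        return lessThan;
    }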
    

    Here's the routine that calls it. Note that k is 1-based and relative to start, which is why it is compared against pivot - start + 1:

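    private static int quickSelect(int[] elements, int k, int start, int end) {

        int pivot = getPivot(elements, start, end);

        // pivot - start + 1 is the pivot's 1-based rank within [start, end]
        if (k == (pivot - start + 1)) {
            System.out.println(elements[pivot]);
            return pivot;
        } else if (k < (pivot - start + 1)) {
            // The answer lies to the left; its rank within that range is unchanged
            return quickSelect(elements, k, start, pivot - 1);
        } else {
            // The answer lies to the right; discount the items skipped on the left
            return quickSelect(elements, k - (pivot - start + 1), pivot + 1, end);
        }
    }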
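
    As mentioned, this can be tidied up with a separate swap method. Below is one possible sketch of that cleanup, not a definitive version: the names QuickSelectSketch, swap, partition and kthSmallest are my own for this example, the range check on k is an added assumption (the code above does not validate k), and the wrapper returns the value rather than the index. The partitioning scheme and the rank arithmetic are the same as above.

    ```java
    class QuickSelectSketch {

        // Hypothetical helper: swaps two positions in place.
        private static void swap(int[] a, int i, int j) {
            int tmp = a[i];
            a[i] = a[j];
            a[j] = tmp;
        }

        // Same scheme as getPivot: the first element is the pivot, items
        // smaller than it are collected right after it, then the pivot is
        // swapped into its final position, which is returned.
        static int partition(int[] a, int start, int end) {
            int lessThan = start;
            for (int i = start + 1; i <= end; i++) {
                if (a[i] < a[start]) {
                    swap(a, ++lessThan, i);
                }
            }
            swap(a, start, lessThan);
            return lessThan;
        }

        // k is 1-based and relative to start. Returns the index of the answer.
        private static int quickSelect(int[] a, int k, int start, int end) {
            int pivot = partition(a, start, end);
            int rank = pivot - start + 1;   // pivot's 1-based rank in [start, end]
            if (k == rank) {
                return pivot;
            } else if (k < rank) {
                return quickSelect(a, k, start, pivot - 1);
            } else {
                return quickSelect(a, k - rank, pivot + 1, end);
            }
        }

        // Returns the kth smallest value (k = 1 is the minimum).
        // Note: reorders the array in place. The range check is an assumption
        // added for this sketch.
        static int kthSmallest(int[] a, int k) {
            if (k < 1 || k > a.length) {
                throw new IllegalArgumentException("k must be between 1 and " + a.length);
            }
            return a[quickSelect(a, k, 0, a.length - 1)];
        }

        public static void main(String[] args) {
            int[] elements = {30, 50, 20, 10};
            System.out.println(kthSmallest(elements, 3)); // prints 30
        }
    }
    ```

    With the array from the question, the first partition already places 30 at index 2, which has rank 3, so the call for k = 3 returns without recursing.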