
Finding the kth smallest element in an unsorted array


I'm trying to implement the following pseudocode. I need to do this using logical partitions only.

Procedure SELECT(k, S)
{ if |S| = 1 then return the single element in S
   else  { choose an element m randomly from S;
          let S1, S2, and S3 be the sequences of elements in S
          less than, equal to, and greater than m, respectively;
          if |S1| >= k then return SELECT(k, S1)
          else
               if |S1| + |S2| >= k then return m
               else return SELECT(k - |S1| - |S2|, S3);
         }
}
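For reference, a literal Java translation of the pseudocode may help pin down its semantics before moving to an in-place version. This sketch copies S into three lists on every call (so it deliberately does not use logical partitions), assumes k is 1-based with 1 <= k <= |S|, and uses the hypothetical class name ListSelect. It can serve as a slow oracle to test an in-place implementation against.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/** Literal, non-in-place translation of SELECT(k, S); k is 1-based, 1 <= k <= |S|. */
class ListSelect {
    static int select(int k, List<Integer> s) {
        if (s.size() == 1) {
            return s.get(0); // |S| = 1: return the single element
        }
        // Choose the pivot value m uniformly at random from S.
        int m = s.get(ThreadLocalRandom.current().nextInt(s.size()));
        List<Integer> s1 = new ArrayList<>(); // elements less than m
        List<Integer> s2 = new ArrayList<>(); // elements equal to m
        List<Integer> s3 = new ArrayList<>(); // elements greater than m
        for (int x : s) {
            if (x < m) {
                s1.add(x);
            } else if (x == m) {
                s2.add(x);
            } else {
                s3.add(x);
            }
        }
        if (s1.size() >= k) {
            return select(k, s1);
        } else if (s1.size() + s2.size() >= k) {
            return m;
        } else {
            return select(k - s1.size() - s2.size(), s3);
        }
    }
}
```

S1 and S3 never contain m itself, so each recursive call works on a strictly smaller list and the recursion terminates.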

Here is my attempt at it so far:

public static int select(int k, int[] s, int arrayLeft, int arrayRight) {
    if (s.length == 1) {
        return s[0];
    } else {
        Random rand = new Random();
        int right = rand.nextInt(arrayRight) + arrayLeft;
        int m = s[right];
        int pivot = partition(s, arrayLeft, right); // pivot = |s1|
        if (pivot >= k) {
            return select(k, s, arrayLeft, pivot - 1);
        } else {
            // Calculate |s2|
            int s2Length = 0;
            for (int i = pivot; s[i] == m; i++) {
                s2Length++;
            }
            if (pivot + s2Length >= k) {
                return m;
            } else {
                int s3Left = pivot + s2Length;
                return select(k - pivot - s2Length, s, s3Left + 1, s.length);
            }
        }
    }
}

// all elements smaller than m are to the left of it,
// all elements greater than m are to the right of it
private static int partition(int[] s, int left, int right) {
    int m = s[right];
    int i = left;
    for (int j = left; j <= right - 1; j++) {
        if (s[j] <= m) {
            swap(s, i, j);
            i++;
        }
    }
    swap(s, i, right);
    return i;
}

private static void swap(int[] s, int i, int j) {
    int temp = s[i];
    s[i] = s[j];
    s[j] = temp;
}

My select method isn't returning the actual kth smallest element. The partition method only does its job properly on the elements smaller than m. On the portion of the array to the right of m, there are elements of any value. How do I fix this? All of the solutions I've seen online appear the same as my method. Any help is appreciated!


Solution

    I am not sure of the details of how your code was supposed to work, but I think I have spotted a few suspicious points.

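    First, I think you should be precise about the valid arguments for your method and how it uses arrayLeft and arrayRight (for example, whether arrayRight is inclusive or exclusive, and whether k counts from 1 within the current subarray or within the whole array). Write a Javadoc comment that states this. It will make it much easier for you and anyone else to reason about what is correct and incorrect in the code.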
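As an illustration, such a contract might look like the following, assuming k is 1-based within the subarray (as in the pseudocode) and both bounds are inclusive; the question does not settle either point, and SelectContract and requireValid are hypothetical names.

```java
/** Hypothetical helper showing how the contract of select can be written down and checked. */
class SelectContract {
    /**
     * Checks the arguments of an in-place selection over s[left..right].
     *
     * @param k     1-based rank within the subarray, 1 <= k <= right - left + 1
     * @param s     the array; only s[left..right] is examined and reordered
     * @param left  index of the first element of the subarray (inclusive)
     * @param right index of the last element of the subarray (inclusive)
     * @throws IllegalArgumentException if any of the above does not hold
     */
    static void requireValid(int k, int[] s, int left, int right) {
        if (s == null) {
            throw new IllegalArgumentException("s is null");
        }
        if (left < 0 || right >= s.length || left > right) {
            throw new IllegalArgumentException("bad range [" + left + ", " + right + "]");
        }
        if (k < 1 || k > right - left + 1) {
            throw new IllegalArgumentException("k out of range: " + k);
        }
    }
}
```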

    This is wrong:

        if (s.length == 1) {
    

    You are passing the same array through all your recursive calls, so if it didn’t have length 1 from the outset (a trivial case), it will never have length 1. Instead use arrayLeft and arrayRight to determine the number of elements to consider.
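A minimal sketch of a range-based base case, assuming inclusive bounds; RangeBaseCase, rangeSize and baseCase are hypothetical names used only for illustration.

```java
/** Hypothetical illustration: the base case depends on the range, not on s.length. */
class RangeBaseCase {
    /** Number of elements in the inclusive range s[left..right]. */
    static int rangeSize(int left, int right) {
        return right - left + 1;
    }

    /** When the range holds exactly one element, that element is the answer. */
    static int baseCase(int[] s, int left, int right) {
        if (rangeSize(left, right) != 1) {
            throw new IllegalStateException("range has more than one element");
        }
        return s[left];
    }
}
```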

    This line does not look right:

            int right = rand.nextInt(arrayRight) + arrayLeft;
    

    If arrayLeft is 10 and arrayRight is 12, it may yield any value from 10 up to 21, because rand.nextInt(12) returns 0 to 11. I did once observe an ArrayIndexOutOfBoundsException in the following line because right pointed outside the array.
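A sketch of drawing an index uniformly from the inclusive range [arrayLeft, arrayRight]: Random.nextInt(bound) returns a value from 0 to bound - 1, so the bound should be the number of elements in the range rather than arrayRight. The class and method names are hypothetical.

```java
import java.util.Random;

class RandomPivotIndex {
    /**
     * Returns a uniformly random index in the inclusive range [left, right].
     * The bound passed to nextInt is the size of the range.
     */
    static int randomIndex(Random rand, int left, int right) {
        return left + rand.nextInt(right - left + 1);
    }
}
```

With left = 10 and right = 12, the bound is 3 and the result is always 10, 11 or 12. Reusing one Random instance, rather than creating a new one on every call as the question's code does, is a common choice.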

    The comment in this line is incorrect and may lead you to reason incorrectly about the code:

            int pivot = partition(s, arrayLeft, right); // pivot = |s1|
    

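    The value returned from partition() is the index of m after reordering, not a count. With distinct values, I think the correct statement is pivot == arrayLeft + |s1|; please check this yourself. Note also that partition() moves every element with s[j] <= m to the left, so any copies of m end up before the pivot (and are counted in it), not after it where your s2Length loop looks for them.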

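    I further believe that you should pass arrayRight, not right, as the last argument in the above call. As it stands, partition() only rearranges s[arrayLeft..right], so the elements after the randomly chosen index are never examined. This may be the cause of your observation that partition() leaves arbitrary values to the right of m. Since partition() uses s[right] as its pivot value, you would also need to swap the randomly chosen element into position arrayRight first.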
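If the partition step should produce S1, S2 and S3 directly, one option (swapped in here; it is not what the question's partition() does) is Dijkstra's three-way "Dutch national flag" partition. This sketch, with the hypothetical name partition3, takes the pivot value m, reorders the inclusive range s[left..right], and returns both boundaries, so that |S1| = lt - left and |S2| = gt - lt + 1 without a separate scan.

```java
/** Hypothetical three-way partition (Dijkstra's "Dutch national flag" scheme). */
class ThreeWayPartition {
    /**
     * Reorders s[left..right] (inclusive) into  < m | == m | > m  and returns {lt, gt}:
     * s[left..lt-1] < m, s[lt..gt] == m, s[gt+1..right] > m.
     */
    static int[] partition3(int[] s, int left, int right, int m) {
        int lt = left;  // next slot for an element < m
        int i = left;   // next element to examine
        int gt = right; // next slot, from the right, for an element > m
        while (i <= gt) {
            if (s[i] < m) {
                swap(s, lt++, i++);
            } else if (s[i] > m) {
                swap(s, i, gt--); // i stays put: the element swapped in is unexamined
            } else {
                i++;
            }
        }
        return new int[] {lt, gt};
    }

    private static void swap(int[] s, int i, int j) {
        int temp = s[i];
        s[i] = s[j];
        s[j] = temp;
    }
}
```

Elements outside s[left..right] are never touched, which is what working with logical partitions requires.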

    You may risk an ArrayIndexOutOfBoundsException here too:

                for (int i = pivot; s[i] == m; i++) {
    

    You should add a bounds check such as i <= arrayRight (or at least i < s.length), placed before s[i] == m so that && short-circuits before the array is read.
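A sketch of the bounded scan, with the hypothetical name countRun: putting the bound check first lets && short-circuit before s[i] is read. It only counts copies of m immediately after from; whether they actually sit there depends on how the partition step handles ties.

```java
class EqualRun {
    /**
     * Counts consecutive elements equal to m starting at index from,
     * never reading past index right (inclusive).
     */
    static int countRun(int[] s, int from, int right, int m) {
        int count = 0;
        // The bound check comes first, so s[i] is only read when i <= right.
        for (int i = from; i <= right && s[i] == m; i++) {
            count++;
        }
        return count;
    }
}
```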

    Finally, this looks wrong to me:

                    return select(k - pivot - s2Length, s, s3Left + 1, s.length);
    

    I would expect something like:

                    return select(k - pivot - s2Length, s, s3Left, arrayRight);
    

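    But please check this against your own understanding; I am particularly unsure about arrayRight. If k is meant to count within the current subarray, note also that pivot is an index, so |s1| is pivot - arrayLeft and the new rank would be k - (pivot - arrayLeft) - s2Length.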
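Putting these points together, here is one possible in-place version, offered as a sketch rather than a definitive fix. It assumes inclusive bounds and a 1-based k that counts within the current subarray, draws the pivot value at random from the range, and swaps in a three-way partition instead of the question's partition(), so |S1| and |S2| come straight from the boundaries. With these conventions the last recursive call does use the original upper bound (arrayRight, here right). InPlaceSelect is a hypothetical name.

```java
import java.util.Random;

/**
 * Sketch of SELECT(k, S) on s[left..right] (inclusive) using only index bounds.
 * k is 1-based within the range, 1 <= k <= right - left + 1. Reorders s in place.
 */
class InPlaceSelect {
    private static final Random RAND = new Random();

    static int select(int k, int[] s, int left, int right) {
        if (left == right) {
            return s[left]; // |S| = 1
        }
        // Choose the pivot value m uniformly at random from s[left..right].
        int m = s[left + RAND.nextInt(right - left + 1)];
        // Three-way partition: s[left..lt-1] < m, s[lt..gt] == m, s[gt+1..right] > m.
        int lt = left, i = left, gt = right;
        while (i <= gt) {
            if (s[i] < m) {
                swap(s, lt++, i++);
            } else if (s[i] > m) {
                swap(s, i, gt--);
            } else {
                i++;
            }
        }
        int s1 = lt - left;   // |S1|
        int s2 = gt - lt + 1; // |S2|, at least 1 because m was taken from the range
        if (s1 >= k) {
            return select(k, s, left, lt - 1);
        } else if (s1 + s2 >= k) {
            return m;
        } else {
            return select(k - s1 - s2, s, gt + 1, right);
        }
    }

    /** Convenience entry point: the k-th smallest (1-based) element of the whole array. */
    static int select(int k, int[] s) {
        return select(k, s, 0, s.length - 1);
    }

    private static void swap(int[] s, int i, int j) {
        int temp = s[i];
        s[i] = s[j];
        s[j] = temp;
    }
}
```

Calling select(k, s) reorders s; pass s.clone() if the caller needs the original order preserved.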