java arrays recurrence inversion

Finding all inversions of an array using Merge Sort


I'm trying to implement the well-known algorithm for counting all inversions in an array using merge sort, but it keeps producing the wrong answer: it counts far too many inversions. I suspect that part or all of the subarrays are being iterated too many times in the recursive calls, but I can't quite put my finger on it. I would appreciate some pointers as to why this might be happening. My Java implementation is below:

public class inversionsEfficient {

  public int mergeSort(int[] list, int[] temp, int left, int right) {
    int count = 0;
    int mid = 0;

    if(right > left) {
      mid = (right+left)/2;
      count += mergeSort(list, temp, left, mid);
      count += mergeSort(list, temp, mid+1, right);
      count += merge(list, temp, left, mid+1, right);
    }

    return count;
  }

  public int merge(int[] list, int[] temp, int left, int mid, int right) {
    int count = 0;
    int i = left;
    int j = mid;
    int k = left;

    while((i<=mid-1) && (j<=right)) {
      if(list[i] <= list[j]) {
        temp[k] = list[i];
        k += 1;
        i += 1;
      }
      else {
        temp[k] = list[j];
        k += 1;
        j += 1;
        count += mid-1;
      }
    }

    while(i<=mid-1) {
      temp[k] = list[i];
      k += 1;
      i += 1;
    }

    while(j<=right) {
      temp[k] = list[j];
      k += 1;
      j += 1;
    }

    for(i=left;i<=right;i++) {
      list[i] = temp[i];
    }

    return count;
  }

  public static void main(String[] args) {
    int[] myList = {5, 3, 76, 12, 89, 22, 5};
    int[] temp = new int[myList.length];

    inversionsEfficient inversions = new inversionsEfficient();
    System.out.println(inversions.mergeSort(myList, temp, 0, myList.length-1));
  }
}

This algorithm is based on pseudocode from Introduction to Algorithms by Cormen et al.: https://i.sstatic.net/ea9No.png


Solution

  • In merge, instead of

    count += mid - 1;

    use

    count += mid - i;

    The complete corrected solution is shown below:

    public class inversionsEfficient {
    
        public int mergeSort(int[] list, int[] temp, int left, int right) {
            int count = 0;
            int mid = 0;
    
            if (right > left) {
                mid = (right + left) / 2;
                count += mergeSort(list, temp, left, mid);
                count += mergeSort(list, temp, mid + 1, right);
                count += merge(list, temp, left, mid + 1, right);
            }
    
            return count;
        }
    
        public int merge(int[] list, int[] temp, int left, int mid, int right) {
            int count = 0;
            int i = left;
            int j = mid;
            int k = left;
    
            while ((i <= mid - 1) && (j <= right)) {
                if (list[i] <= list[j]) {
                    temp[k] = list[i];
                    k += 1;
                    i += 1;
                } else {
                    temp[k] = list[j];
                    k += 1;
                    j += 1;
                    count += mid - i; // (mid - i), not (mid - 1)
                }
            }
    
            while (i <= mid - 1) {
                temp[k] = list[i];
                k += 1;
                i += 1;
            }
    
            while (j <= right) {
                temp[k] = list[j];
                k += 1;
                j += 1;
            }
    
            for (i = left; i <= right; i++) {
                list[i] = temp[i];
            }
    
            return count;
        }
    
        public static void main(String[] args) {
            int[] arr = {5, 3, 76, 12, 89, 22, 5};
            int[] temp = new int[arr.length];
    
            inversionsEfficient inversions = new inversionsEfficient();
            System.out.println(inversions.mergeSort(arr, temp, 0, arr.length - 1));
        }
    
    }
    

    For the example array from the question, the code above prints 8. This is correct: [5, 3, 76, 12, 89, 22, 5] contains exactly 8 inversions (the two 5s do not form one, because an inversion requires the earlier element to be strictly greater):

     1. (5, 3)
     2. (76, 12)
     3. (76, 22)
     4. (76, 5)
     5. (12, 5)
     6. (89, 22)
     7. (89, 5)
     8. (22, 5)
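
The list above can be double-checked with a brute-force count that compares every pair directly. This is only a minimal sketch for verification, not part of the merge sort solution; the class name BruteForceInversions and the method countInversions are made up for this illustration.

```java
class BruteForceInversions {

    // Counts pairs (p, q) with p < q and arr[p] > arr[q] by testing every pair.
    // Runs in O(n^2) time, so it is only suitable for checking small inputs.
    static int countInversions(int[] arr) {
        int count = 0;
        for (int p = 0; p < arr.length; p++) {
            for (int q = p + 1; q < arr.length; q++) {
                if (arr[p] > arr[q]) { // strictly greater: equal elements are not inverted
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] arr = {5, 3, 76, 12, 89, 22, 5};
        System.out.println(countInversions(arr)); // prints 8
    }
}
```

Comparing this count with the merge sort result on a few small arrays is a quick way to catch counting mistakes like the one in the question.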
    

    Explanation for Code Change

    The algorithm computes the total number of inversions as the sum of three parts: the inversions within the left subarray, the inversions within the right subarray, and the inversions found while merging the two (pairs with one element in each half).

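    If list[i] > list[j] during the merge, then list[j] forms an inversion with every element still remaining in the left subarray, because both subarrays are already sorted: list[i], list[i+1], ..., list[mid-1] are all greater than list[j]. Note that in this code mid is the first index of the right subarray (merge is called with mid + 1), so the left subarray ends at mid - 1 and exactly (mid - i) elements remain. The original expression (mid - 1) ignores i and adds an amount that depends only on where the subarray sits in the array, which explains the wrong count.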
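
The counting rule can be seen in isolation with two already sorted arrays. The sketch below is an illustration only: CrossInversions and countCrossInversions are hypothetical names, and the two halves are passed as separate arrays rather than as index ranges of one array, so left.length - i plays the role of mid - i.

```java
class CrossInversions {

    // Counts pairs (x, y) with x taken from left, y taken from right, and x > y.
    // Assumes both arrays are already sorted in ascending order.
    static int countCrossInversions(int[] left, int[] right) {
        int count = 0;
        int i = 0;
        int j = 0;
        while (i < left.length && j < right.length) {
            if (left[i] <= right[j]) {
                i++; // left[i] is not greater than right[j], so it forms no inversion with it
            } else {
                // left[i] > right[j], and left is sorted, so every element
                // from index i to the end of left is also greater than right[j].
                count += left.length - i;
                j++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // The two halves of {5, 3, 76, 12, 89, 22, 5} after each has been sorted.
        int[] left = {3, 5, 12, 76};
        int[] right = {5, 22, 89};
        System.out.println(countCrossInversions(left, right)); // prints 3
    }
}
```

For the example array, the left half {5, 3, 76, 12} contains 2 inversions and the right half {89, 22, 5} contains 3; together with the 3 cross inversions counted here, that gives the total of 8.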

    For a more detailed explanation, have a look at the GeeksForGeeks article on Counting Inversions.