
I'm trying to implement the well-known algorithm for counting all inversions in an array using merge sort, but it keeps producing the wrong answer: it counts far too many inversions. I suspect that part or all of the subarrays are being iterated too many times in the recursive calls, but I can't quite put my finger on it. I would appreciate some pointers as to why this might be happening. My Java implementation is below:

public class inversionsEfficient {

  public int mergeSort(int[] list, int[] temp, int left, int right) {
    int count = 0;
    int mid = 0;

    if(right > left) {
      mid = (right+left)/2;
      count += mergeSort(list, temp, left, mid);
      count += mergeSort(list, temp, mid+1, right);
      count += merge(list, temp, left, mid+1, right);
    }

    return count;
  }

  public int merge(int[] list, int[] temp, int left, int mid, int right) {
    int count = 0;
    int i = left;
    int j = mid;
    int k = left;

    while((i<=mid-1) && (j<=right)) {
      if(list[i] <= list[j]) {
        temp[k] = list[i];
        k += 1;
        i += 1;
      }
      else {
        temp[k] = list[j];
        k += 1;
        j += 1;
        count += mid-1;
      }
    }

    while(i<=mid-1) {
      temp[k] = list[i];
      k += 1;
      i += 1;
    }

    while(j<=right) {
      temp[k] = list[j];
      k += 1;
      j += 1;
    }

    for(i=left;i<=right;i++) {
      list[i] = temp[i];
    }

    return count;
  }

  public static void main(String[] args) {
    int[] myList = {5, 3, 76, 12, 89, 22, 5};
    int[] temp = new int[myList.length];

    inversionsEfficient inversions = new inversionsEfficient();
    System.out.println(inversions.mergeSort(myList, temp, 0, myList.length-1));
  }
}

This algorithm is based on the pseudocode from Introduction to Algorithms by Cormen et al.: https://i.stack.imgur.com/ea9No.png

milkdose

1 Answer


Instead of

count += mid - 1;

use

count += mid - i;

The complete corrected solution is shown below:

public class inversionsEfficient {

    public int mergeSort(int[] list, int[] temp, int left, int right) {
        int count = 0;
        int mid = 0;

        if (right > left) {
            mid = (right + left) / 2;
            count += mergeSort(list, temp, left, mid);
            count += mergeSort(list, temp, mid + 1, right);
            count += merge(list, temp, left, mid + 1, right);
        }

        return count;
    }

    public int merge(int[] list, int[] temp, int left, int mid, int right) {
        int count = 0;
        int i = left;
        int j = mid;
        int k = left;

        while ((i <= mid - 1) && (j <= right)) {
            if (list[i] <= list[j]) {
                temp[k] = list[i];
                k += 1;
                i += 1;
            } else {
                temp[k] = list[j];
                k += 1;
                j += 1;
                count += mid - i; // (mid - i), not (mid - 1)
            }
        }

        while (i <= mid - 1) {
            temp[k] = list[i];
            k += 1;
            i += 1;
        }

        while (j <= right) {
            temp[k] = list[j];
            k += 1;
            j += 1;
        }

        for (i = left; i <= right; i++) {
            list[i] = temp[i];
        }

        return count;
    }

    public static void main(String[] args) {
        int[] arr = {5, 3, 76, 12, 89, 22, 5};
        int[] temp = new int[arr.length];

        inversionsEfficient inversions = new inversionsEfficient();
        System.out.println(inversions.mergeSort(arr, temp, 0, arr.length - 1));
    }

}

The code above prints 8 for the example array in the question, which is correct: the array [5, 3, 76, 12, 89, 22, 5] contains exactly 8 inversions:

 1. (5, 3)
 2. (76, 12)
 3. (76, 22)
 4. (76, 5)
 5. (12, 5)
 6. (89, 22)
 7. (89, 5)
 8. (22, 5)
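
The count above can be cross-checked independently with a brute-force pass over all pairs. This is only a minimal sketch for verification, not part of the merge sort solution; the class name BruteForceInversions and the method count are hypothetical names introduced here. It runs in O(n^2) time, so it is only suitable for small test arrays.

```java
// Hypothetical helper class for verification; not part of the original answer.
class BruteForceInversions {

    // Counts pairs (i, j) with i < j and arr[i] > arr[j] in O(n^2) time.
    static long count(int[] arr) {
        long inversions = 0;
        for (int i = 0; i < arr.length; i++) {
            for (int j = i + 1; j < arr.length; j++) {
                if (arr[i] > arr[j]) {
                    inversions++;
                }
            }
        }
        return inversions;
    }

    public static void main(String[] args) {
        int[] arr = {5, 3, 76, 12, 89, 22, 5};
        System.out.println(count(arr)); // prints 8
    }
}
```

Comparing this result with the merge sort result on a few small or random arrays is a quick way to catch counting mistakes such as the one in the question.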

Explanation for Code Change

The algorithm computes the total number of inversions as the sum of three parts: the inversions within the left subarray, the inversions within the right subarray, and the inversions found during the merge step, where one element comes from each half.

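If list[i] > list[j], then list[j] forms an inversion with every element still remaining in the left subarray, because both subarrays are already sorted. In this code, mid is the first index of the right subarray (merge is called with mid + 1), so the left subarray ends at index mid - 1, and the remaining elements list[i], list[i+1], …, list[mid-1] are all greater than list[j]. That gives (mid - i) inversions. The original expression (mid - 1) does not depend on i at all, which is why the count came out wrong.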
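
To make the (mid - i) step concrete, here is a small sketch that isolates only the counting part of the merge, applied to the final merge of the example array. It assumes the two halves have already been sorted by the recursive calls, giving [3, 5, 12, 76] and [5, 22, 89]. The class name MergeStepDemo and the method crossInversions are hypothetical names used for illustration; as in the code above, mid is the first index of the right half.

```java
// Hypothetical demo class; isolates the counting logic of the merge step.
class MergeStepDemo {

    // Counts inversions between sorted list[left..mid-1] and sorted list[mid..right].
    // 'mid' is the first index of the right half, matching merge() above.
    static int crossInversions(int[] list, int left, int mid, int right) {
        int count = 0;
        int i = left;
        int j = mid;
        while (i <= mid - 1 && j <= right) {
            if (list[i] <= list[j]) {
                i++;
            } else {
                // list[i..mid-1] are all greater than list[j]: (mid - i) pairs.
                count += mid - i;
                j++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // The two halves of the example array after each has been sorted.
        int[] list = {3, 5, 12, 76, 5, 22, 89};
        System.out.println(crossInversions(list, 0, 4, 6)); // prints 3
    }
}
```

The 3 cross inversions here are (12, 5), (76, 5) and (76, 22). Together with the 2 inversions inside the left half [5, 3, 76, 12] and the 3 inside the right half [89, 22, 5], this gives the total of 8.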

For a more detailed explanation, have a look at the GeeksForGeeks article on Counting Inversions.

Dhruv Saraswat