
I am trying to implement the median-of-medians algorithm in Java. The algorithm should determine the median of a set of numbers. I tried to implement the pseudocode on Wikipedia:

https://en.wikipedia.org/wiki/Median_of_medians

I am getting a stack overflow and don't know why. Because of the recursion, I find it quite difficult to keep track of the code.

import java.util.Arrays;

public class MedianSelector {
    private static final int CHUNK = 5;
    
    public static void main(String[] args) {
        int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
        lowerMedian(test);
        System.out.print(Arrays.toString(test));
    }
    
    /**
     * Computes and retrieves the lower median of the given array of
     * numbers using the Median algorithm presented in the lecture.
     * 
     * @param numbers the input numbers.
     * @return the lower median.
     * @throws IllegalArgumentException if the array is {@code null} or empty.
    */
    public static int lowerMedian(int[] numbers) {
        if(numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException();
        }
        
        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }
    
    private static int select(int[] numbers, int left, int right, int i) {
        
        if(left == right) {
            return left;
        }
        
        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);
        
        if(i == pivotIndex) {
            return i;
        }else if(i < pivotIndex) {
            return select(numbers, left, pivotIndex - 1, i); 
        }else {
            return select(numbers, left, pivotIndex + 1, i);
        }
    }
    
    private static int pivot(int[] numbers, int left, int right) {
        if(right - left < CHUNK) {
            return partition5(numbers, left, right);
        }
        
        for(int i=left; i<=right; i=i+CHUNK) {
            int subRight = i + (CHUNK-1);
            
            if(subRight > right) {
                subRight = right;
            }
            
            int medChunk = partition5(numbers, i, subRight);
                    
            int tmp = numbers[medChunk];
            numbers[medChunk] = numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))];
            numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))] = tmp;
        }
        
        int mid = (right - left) / 10 + left +1;
        return select(numbers, left, (int) (left + Math.floor((right - left) / CHUNK)), mid);
    }
    
    private static int partition(int[] numbers, int left, int right, int idx, int k) {
        int pivotVal = numbers[idx];
        int storeIndex = left;
        int storeIndexEq = 0;
        int tmp = 0;
        
        tmp = numbers[idx];
        numbers[idx] = numbers[right];
        numbers[right] = tmp;
        
        for(int i=left; i<right; i++) {
            if(numbers[i] < pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndex];
                numbers[storeIndex] = tmp;
                storeIndex++;
            }
        }
        
        storeIndexEq = storeIndex;
        
        for(int i=storeIndex; i<right; i++) {
            if(numbers[i] == pivotVal) {
                tmp = numbers[i];
                numbers[i] = numbers[storeIndexEq];
                numbers[storeIndexEq] = tmp;
                storeIndexEq++;
            }
        }
        
        tmp = numbers[right];
        numbers[right] = numbers[storeIndexEq];
        numbers[storeIndexEq] = tmp;
        
        if(k < storeIndex) {
            return storeIndex;
        }
        
        if(k <= storeIndexEq) {
            return k;
        }
           
        return storeIndexEq;
    }
    
    //Insertion sort
    private static int partition5(int[] numbers, int left, int right) {
        int i = left + 1;
        int j = 0;
        
        while(i<=right) {
            j= i;
            while(j>left && numbers[j-1] > numbers[j]) {
                int tmp = numbers[j-1];
                numbers[j-1] = numbers[j];
                numbers[j] = tmp;
                j=j-1;
            }
            i++;
        }
        
        return left + (right - left) / 2;
    }
}

Can you confirm that n (in the pseudocode) or i (in my code) stands for the position of the median? So let's assume our array is numbers = {9,8,7,6,5,4,3,2,1,0}. I would call select(numbers, 0, 9, 4), correct?
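In the question's code, i is an absolute, 0-based position in the whole array (partition compares k directly with storeIndex, which starts at left), not an offset from left, and lowerMedian passes (numbers.length - 1) / 2. For the 10-element example that is 4, the index the lower median would occupy after sorting. A small sketch, using Arrays.sort on a copy as a reference; the class LowerMedianIndex and its method are hypothetical names for illustration only:

```java
import java.util.Arrays;

public class LowerMedianIndex {
    // 0-based index of the lower median in a sorted array of length n,
    // i.e. the value lowerMedian() in the question passes to select() as i.
    static int lowerMedianIndex(int n) {
        return (n - 1) / 2;
    }

    public static void main(String[] args) {
        int[] numbers = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
        int[] sorted = numbers.clone();
        Arrays.sort(sorted);

        int i = lowerMedianIndex(numbers.length);
        // select(numbers, 0, numbers.length - 1, i) is expected to return an
        // index that holds this same value once selection has finished.
        System.out.println("i = " + i + ", lower median = " + sorted[i]); // prints "i = 4, lower median = 4"
    }
}
```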

I don't understand the calculation of mid in pivot. Why is there a division by 10? Maybe there is a mistake in the pseudocode?
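The division by 10 can be checked with a little arithmetic. After the loop in pivot, the group medians occupy left to left + (right - left) / 5, so there are (right - left) / 5 + 1 of them, and their lower median sits at left + (right - left) / 10 (halving (right - left) / 5 with integer division is the same as dividing by 10). The sketch below, with the hypothetical names MidCheck, mid and lowerMedianOfMedians, compares that position with the mid formula used in the question:

```java
public class MidCheck {
    // The mid formula from pivot() in the question.
    static int mid(int left, int right) {
        return (right - left) / 10 + left + 1;
    }

    // Absolute index of the lower median among the group medians,
    // which pivot() moves to left .. left + (right - left) / 5.
    static int lowerMedianOfMedians(int left, int right) {
        int count = (right - left) / 5 + 1; // number of five-element groups
        return left + (count - 1) / 2;
    }

    public static void main(String[] args) {
        int left = 0;
        for (int n : new int[] {6, 10, 13, 20, 25}) {
            int right = left + n - 1;
            System.out.printf("n=%d medians=%d lowerMedianPos=%d mid=%d%n",
                    n, (right - left) / 5 + 1,
                    lowerMedianOfMedians(left, right), mid(left, right));
        }
    }
}
```

For every range size the formula lands exactly one position past the lower median of the medians, because of the trailing + 1. That does not change which element select finally returns, since any pivot inside the range gives the correct result; it only affects how good the pivot is.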

Thanks for your help.

user73347

1 Answer


EDIT: It turns out the switch from iteration to recursion was a red herring. The actual issue, identified by the OP, was in the arguments to the 2nd recursive select call.

This line:

return select(numbers, left, pivotIndex + 1, i);

should be

return select(numbers, pivotIndex + 1, right, i);
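For reference, here is a sketch of the question's MedianSelector with only that argument fix applied to the selection logic. The swap helper, Math.min for subRight, and the randomized cross-check against Arrays.sort in main are additions for readability and testing, not part of the original code; the seed and array sizes are arbitrary choices.

```java
import java.util.Arrays;
import java.util.Random;

public class MedianSelector {
    private static final int CHUNK = 5;

    /** Returns the lower median; throws IllegalArgumentException for null or empty input. */
    public static int lowerMedian(int[] numbers) {
        if (numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException("numbers must be non-empty");
        }
        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }

    // Returns the index of the i-th smallest element (i is absolute and 0-based, left <= i <= right).
    private static int select(int[] numbers, int left, int right, int i) {
        if (left == right) {
            return left;
        }
        int pivotIndex = pivot(numbers, left, right);
        pivotIndex = partition(numbers, left, right, pivotIndex, i);
        if (i == pivotIndex) {
            return i;
        } else if (i < pivotIndex) {
            return select(numbers, left, pivotIndex - 1, i);
        } else {
            return select(numbers, pivotIndex + 1, right, i); // the fix: keep 'right' as the upper bound
        }
    }

    private static int pivot(int[] numbers, int left, int right) {
        if (right - left < CHUNK) {
            return partition5(numbers, left, right);
        }
        for (int i = left; i <= right; i += CHUNK) {
            int subRight = Math.min(i + CHUNK - 1, right);
            int medChunk = partition5(numbers, i, subRight);
            swap(numbers, medChunk, left + (i - left) / CHUNK); // move the group median to the front
        }
        int mid = (right - left) / 10 + left + 1; // unchanged from the question
        return select(numbers, left, left + (right - left) / CHUNK, mid);
    }

    // Three-way partition around numbers[idx]; returns the index in the "equal" block closest to k.
    private static int partition(int[] numbers, int left, int right, int idx, int k) {
        int pivotVal = numbers[idx];
        swap(numbers, idx, right);
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (numbers[i] < pivotVal) {
                swap(numbers, i, storeIndex++);
            }
        }
        int storeIndexEq = storeIndex;
        for (int i = storeIndex; i < right; i++) {
            if (numbers[i] == pivotVal) {
                swap(numbers, i, storeIndexEq++);
            }
        }
        swap(numbers, right, storeIndexEq);
        if (k < storeIndex) {
            return storeIndex;
        }
        return Math.min(k, storeIndexEq);
    }

    // Insertion-sorts numbers[left..right] and returns the index of its (lower) median.
    private static int partition5(int[] numbers, int left, int right) {
        for (int i = left + 1; i <= right; i++) {
            for (int j = i; j > left && numbers[j - 1] > numbers[j]; j--) {
                swap(numbers, j - 1, j);
            }
        }
        return left + (right - left) / 2;
    }

    private static void swap(int[] a, int x, int y) {
        int tmp = a[x];
        a[x] = a[y];
        a[y] = tmp;
    }

    public static void main(String[] args) {
        int[] test = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 13, 11, 10};
        System.out.println("Lower median: " + lowerMedian(test.clone()));
        // Randomized cross-check against sorting a copy.
        Random rnd = new Random(42);
        for (int trial = 0; trial < 1000; trial++) {
            int[] a = rnd.ints(1 + rnd.nextInt(60), 0, 20).toArray();
            int[] sorted = a.clone();
            Arrays.sort(sorted);
            if (lowerMedian(a) != sorted[(sorted.length - 1) / 2]) {
                throw new AssertionError("mismatch for " + Arrays.toString(sorted));
            }
        }
        System.out.println("random checks passed");
    }
}
```

Drawing values from 0 to 19 produces many duplicates, which exercises the equal-elements block in partition.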

I'll leave the original answer below, as I don't want to appear cleverer than I actually was.


I think you may have misinterpreted the pseudocode for the select method: it uses iteration rather than recursion.

Here's your current implementation:

private static int select(int[] numbers, int left, int right, int i) {
    
    if(left == right) {
        return left;
    }
    
    int pivotIndex = pivot(numbers, left, right);
    pivotIndex = partition(numbers, left, right, pivotIndex, i);
    
    if(i == pivotIndex) {
        return i;
    }else if(i < pivotIndex) {
        return select(numbers, left, pivotIndex - 1, i); 
    }else {
        return select(numbers, left, pivotIndex + 1, i);
    }
}

And the pseudocode

function select(list, left, right, n)
    loop
        if left = right then
            return left
        pivotIndex := pivot(list, left, right)
        pivotIndex := partition(list, left, right, pivotIndex, n)
        if n = pivotIndex then
            return n
        else if n < pivotIndex then
            right := pivotIndex - 1
        else
            left := pivotIndex + 1

This would typically be implemented using a while loop:

  private static int select(int[] numbers, int left, int right, int i) {
      while(true)
      {
          if(left == right) {
              return left;
          }
          
          int pivotIndex = pivot(numbers, left, right);
          pivotIndex = partition(numbers, left, right, pivotIndex, i);
          
          if(i == pivotIndex) {
              return i;
          }else if(i < pivotIndex) {
              right = pivotIndex - 1; 
          }else {
              left = pivotIndex + 1;
          }
      }
  }

With this change your code appears to work, though obviously you'll need to test to confirm.

int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
System.out.println("Lower Median: " + lowerMedian(test));

int[] check = test.clone();
Arrays.sort(check);
System.out.println(Arrays.toString(check));

Output:

Lower Median: 6
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13]
RaffleBuffle
  • That wasn't quite the issue. Recursion works as well. The arguments of select() for left and right were wrong. You put me on the right track. Your answer works as well. Thanks. – user73347 Jun 10 '21 at 15:15
  • Ah right, good catch. I've updated the answer with this information. Also, I'd have no issue if you decided to delete the question, as there aren't really any generally applicable issues involved. – RaffleBuffle Jun 10 '21 at 15:53