I have a function that returns a numerical value for an instance, and I later use this value to classify the instance into one of three categories. The categories are relatively separable; see the figure below (the three colors represent the three classes).

[Figure: Histogram over occurrences of different classes]

So here I would like two thresholds, k1 and k2, so that everything left of k1 is classified as red, everything right of k2 is classified as blue, and everything in the middle is classified as green.

I started with a modified version of Kadane's algorithm, based on this solution. I first sorted all (color, value) tuples by their value, and then generated an array in which every green instance is given the value 1 and every non-green instance -1. So I would get an array that looked something like this:

 [-1, -1, -1, -1, 1, -1, -1, ..., 1, 1, 1, -1, 1, ..., -1, -1, -1, -1]

That is, initially there are lots of -1 (the reds), then around the middle there are a lot of greens, and towards the end it's mostly blues. Now, by running Kadane's algorithm, will I get the optimal split?

Here's the code I tested with:

import java.util.*;

public class Kadanes {
    private static Color[] correctClasses = new Color[]{Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, 
Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.RED, Color.BLUE, Color.RED, Color.GREEN, Color.RED, Color.RED, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, 
Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.GREEN, Color.RED, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.RED, Color.GREEN, Color.GREEN, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.GREEN, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.GREEN, Color.GREEN, Color.RED, Color.GREEN, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.BLUE, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.RED, Color.RED, Color.RED, Color.BLUE, Color.BLUE, Color.BLUE, Color.BLUE, Color.GREEN};
    private static double[] predictedValues = new double[]{0.0, 0.34, 2.0, 2.67, 7.53, -0.04, 2.0, -3.55, 3.78, 0.33, 3.0, -0.21, 1.41, -0.37, 0.84, 3.94, 8.34, 0.0, -1.39, 3.0, -1.63, 0.0, 3.0, 1.26, 0.0, 0.0, 0.0, 0.0, 0.61, 0.0, 3.34, 0.57, -1.05, 0.63, 0.0, 0.71, 0.0, 2.34, -0.41, -1.77, 3.0, 0.62, 0.93, 1.55, 2.0, 8.0, -1.55, 5.75, 0.0, 0.0, -0.25, 0.0, 1.0, 10.51, 0.0, 0.47, 0.78, -1.08, -1.51, 1.0, 1.0, 0.0, 4.33, -0.6, 0.37, 6.0, 1.16, -4.07, 2.0, 0.91, -0.05, 1.78, 0.0, 0.0, 0.0, 0.0, 0.0, 1.64, 1.55, 4.44, 2.78, 1.47, 3.75, 0.0, 7.59, 0.0, 0.94, 2.46, -0.23, -0.2, 0.0, 0.39, -2.31, 3.0, -1.15, 2.0, -0.76, -1.33, 0.0, 0.61, 0.77, -1.77, -1.08, 0.0, -3.2, 3.46, 1.0, 0.0, 0.0, 3.33, 0.0, 0.0, 2.81, 0.0, 0.0, 0.0, 3.0, 0.0, -0.88, 1.65, -1.09, -0.35, 0.0, 0.0, 5.0, 0.0, 2.88, -0.72, 0.87, 7.0, 7.48, -1.98, 1.0, 1.11, 4.0, 1.53, 0.0, 8.07, 1.54, 4.23, 0.0, -0.73, 6.61, 0.07, 0.0, -4.32, -1.77, 2.05, -1.08, 4.3, 1.61, 2.96, 3.0, 0.0, 3.66, 0.0, 0.0, 0.05, -0.77, -1.0, 0.0, 5.43, 2.12, -1.55, 2.3, 0.0, 3.6, 0.0, 0.0, -10.21, 2.0, 0.55, -0.63, 0.0, 1.0, 0.0, 0.0, 1.28, 3.0, 0.0, 0.44, 1.27, 2.12, 2.17, 1.76, -1.9, 5.42, 1.0, 3.76, -3.55, -0.82, 0.0, 0.11, -1.7, -0.33, 0.0, 0.0, -2.01, 0.0, 3.52, 2.0, 6.0, 0.92, 7.22, 0.0, 0.0, 0.0, 0.0, 0.36, -1.77, 0.0, -3.32, -0.91, 2.69, -0.86, -0.27, 3.28, -1.02, 0.41, -0.6, 2.61, 0.0, 0.36, 0.0, 0.91, 0.0, -2.82, 0.0, -1.77, 0.0, -0.33, 3.94, -2.55, 8.0, 3.29, 2.7, -4.4, 9.0, 0.0, 2.81, -0.23, -2.51, 2.0, -0.19, 0.0, 0.0, 0.0, 0.8, 8.33, 0.0, 0.59, 0.0, 0.41, 0.0, 0.8, 1.7, 3.27, 0.0, 0.34, -1.83, 0.0, -1.0, 0.29, 3.71, -0.44, -0.59, 1.25, 2.3, -1.56, 0.0, 6.21, -0.68, 0.0, 0.0, -0.3, 0.0, 1.0, 0.86, 0.0, 0.0, 0.0, 0.0, 0.41, 1.91, -0.17, -0.77, 1.0, 3.0, 2.0, 3.0, -0.71, 0.0, 0.62, 0.0, 2.54, 1.14, 0.0, 0.0, 3.27, 0.0, 0.96, -0.33, 0.0, 0.0, 1.91, -0.2, 0.0, 0.0, 0.6, 0.0, -0.82, 1.0, -0.54, 6.52, -2.48, 2.0, 0.0, 0.0, 1.61, 0.0, 0.0, 0.0, -0.17, 0.0, 1.0, -5.36, 2.73, 0.0, 7.97, 3.67, 0.0, -0.88, 0.93, 0.0, 3.0, -1.03, 
-0.64, 2.78, 0.0, 1.0, 3.0, 0.0, 0.46, 0.0, -0.63, 0.0, 4.0, 4.0, 1.61, 0.0, 0.0, 1.07, 0.0, 1.0, 18.39, -1.82, 0.0, 0.86, -0.42, -1.77, -0.61, 0.0, 0.68, -3.13, 0.53, 0.0, 3.0, 0.0, 2.47, 0.0, -1.74, 5.31, 0.0, 0.3, 0.0, 0.0, 4.0, 1.0, 0.64, 1.0, 0.0, -1.77, 3.31, -1.77, -0.43, -3.55, 0.94, 8.59, 0.0, 1.81, 3.69, -1.77, -0.32, 0.0, 3.0, 1.93, -1.47, 1.0, 3.21, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0, -0.39, 0.0, 1.0, 0.0, 1.98, 0.0, 0.0, 7.45, 0.72, 0.34, 0.0, 0.35, 0.0, -2.74, 0.28, 4.0, 3.0, -0.91, -4.43, 0.0, 2.28, 3.0, -2.5, -2.66, 2.0, -0.66, 3.0, 11.06, 1.43, 3.0, 0.0, -0.79, 6.3, 0.94, 3.92, -4.43, 5.14, -2.35, 8.83, 1.04, 2.6, 5.0, 3.72};

    private static List<Tuple> previousResults = new ArrayList<>();
    static {
        for(int i=0; i<correctClasses.length; i++) {
            previousResults.add(new Tuple(correctClasses[i], predictedValues[i]));
        }
    }


    public static void main(String[] args) {
        double[] exampleThresholds = new double[]{-1.65, 1.65};
        double[] thresholds = getThreshold();
        System.out.println(Arrays.toString(thresholds));

        System.out.println("Example threshold accuracy: " + getAccuracy(exampleThresholds));
        System.out.println("Optimal threshold accuracy: " + getAccuracy(thresholds));
    }


    private static double[] getThreshold() {
        // Tuple.compareTo orders by descending value, so reverseOrder() sorts ascending by value.
        Collections.sort(previousResults, Collections.reverseOrder());

        int max_so_far = 0;
        int max_ending_here = 0;
        int max_start_index = 0;
        int startIndex = 0;
        int max_end_index = -1;

        for(int i = 0; i < previousResults.size(); i++) {
            // Score +1 for a green instance, -1 for a red or blue one.
            int currentElementScore = (previousResults.get(i).correct == Color.GREEN ? 1 : -1);
            if(max_ending_here + currentElementScore < 0) {
                startIndex = i+1;
                max_ending_here = 0;
            } else {
                max_ending_here += currentElementScore;
            }

            if(max_ending_here > max_so_far) {
                max_so_far = max_ending_here;
                max_start_index = startIndex;
                max_end_index = i;
            }
        }

        // Put each threshold halfway between the neighbouring values at the range boundaries.
        double lowThreshold = getAvgValue(max_start_index-1, max_start_index);
        double highThreshold = getAvgValue(max_end_index, max_end_index+1);

        return new double[]{lowThreshold, highThreshold};
    }


    private static double getAccuracy(double[] thresholds) {
        int numCorrectlyClassified = 0;
        for(int i=0; i<correctClasses.length; i++) {
            Color predictedClassification = classify(predictedValues[i], thresholds[0], thresholds[1]);
            if(predictedClassification == correctClasses[i]) {
                numCorrectlyClassified++;
            }
        }

        return (double) numCorrectlyClassified / correctClasses.length;
    }

    private static Color classify(double value, double lowThresh, double highThresh) {
        if(value < lowThresh) return Color.RED;
        if(value > highThresh) return Color.BLUE;
        return Color.GREEN;
    }


    private static double getAvgValue(int index1, int index2) {
        if(index1 < 0) {
            return Double.NEGATIVE_INFINITY;
        } else if (index2 >= previousResults.size()) {
            return Double.POSITIVE_INFINITY;
        }

        return (previousResults.get(index1).predicted + previousResults.get(index2).predicted) / 2;
    }


    static class Tuple implements Comparable<Tuple> {
        private Color correct;
        private double predicted;

        Tuple(Color correct, double predicted) {
            this.correct = correct;
            this.predicted = predicted;
        }

        public String toString() {
            return "[" + correct.name() + ", " + predicted + "]";
        }

        @Override
        public int compareTo(Tuple o) {
            double diff = o.predicted - predicted;
            return diff != 0 ? (int) Math.signum(diff) : correct.compareTo(o.correct);
        }
    }

    enum Color {
        BLUE, GREEN, RED
    }
}

The output I get is:

[0.0, 0.0]
Example threshold accuracy: 0.5602678571428571
Optimal threshold accuracy: 0.49107142857142855

So the "optimal" thresholds it finds both fall at 0.0, while the example thresholds I entered by hand perform much better. Is the implementation wrong, or is it not possible to use Kadane's algorithm to solve this simple problem? If not, which algorithm could I use?

Limon

1 Answer

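You can't use Kadane's algorithm for this problem because it only optimizes the greens taken: it maximizes greens minus non-greens inside the chosen range and ignores how well red and blue are classified outside it. Say you have something like: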

1, 1, -1, 1, -1, 1, ..., -1, 1, 1, 1

The algorithm wants to maximize the sum, so it takes the first two ones and the last three; since the sum over the middle portion is just -1, the selected range ends up spanning the entire array.
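
To make this concrete, here is a minimal sketch of a standard Kadane scan run on a short array of the same shape. The ten-element array and the class and method names are made up for illustration; they are not taken from the question's data.

```java
import java.util.Arrays;

class KadaneSpanDemo {
    // Standard Kadane scan.
    // Returns {startIndex, endIndexInclusive, sum} of a maximum-sum subarray.
    static int[] maxSubarray(int[] a) {
        int bestSum = Integer.MIN_VALUE, bestStart = 0, bestEnd = -1;
        int curSum = 0, curStart = 0;
        for (int i = 0; i < a.length; i++) {
            if (curSum <= 0) {
                // A non-positive running sum never helps, so restart here.
                curSum = a[i];
                curStart = i;
            } else {
                curSum += a[i];
            }
            if (curSum > bestSum) {
                bestSum = curSum;
                bestStart = curStart;
                bestEnd = i;
            }
        }
        return new int[]{bestStart, bestEnd, bestSum};
    }

    public static void main(String[] args) {
        // Illustrative array: +1 = green, -1 = not green.
        int[] scores = {1, 1, -1, 1, -1, 1, -1, 1, 1, 1};
        // Prints [0, 9, 4]: the best range covers the whole array.
        System.out.println(Arrays.toString(maxSubarray(scores)));
    }
}
```

Because the mixed middle section never drags the running sum below zero, the scan has no reason to restart, and the reported range runs from the first element to the last.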

So instead I would use an O(N log N) algorithm to find the thresholds. First sort the array. Precompute partial count arrays holding the number of greens, reds and blues between index 0 and x (for all x). Then iterate the first threshold through the sorted array and binary search the best position for the second one, using accuracy as the guiding metric. The accuracy can be computed from the precomputed partial counts. For the binary search you need to see which way the metric increases.
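
The partial counts could look roughly like the sketch below. It assumes the labels are already sorted by predicted value and encoded as 0 = red, 1 = green, 2 = blue; that encoding and the names `PrefixCounts`, `build` and `correct` are hypothetical, chosen only for this example. It covers the counting step, not the binary search.

```java
class PrefixCounts {
    // Assumed encoding for this sketch: 0 = red, 1 = green, 2 = blue.
    // prefix[c][k] = number of items with label c among the first k sorted items.
    static int[][] build(int[] sortedLabels) {
        int n = sortedLabels.length;
        int[][] prefix = new int[3][n + 1];
        for (int k = 0; k < n; k++) {
            for (int c = 0; c < 3; c++) {
                prefix[c][k + 1] = prefix[c][k] + (sortedLabels[k] == c ? 1 : 0);
            }
        }
        return prefix;
    }

    // Items [0, i) are predicted red, [i, j) green, [j, n) blue.
    // Returns the number of correct predictions in O(1).
    static int correct(int[][] prefix, int i, int j) {
        int n = prefix[0].length - 1;
        int red = prefix[0][i];
        int green = prefix[1][j] - prefix[1][i];
        int blue = prefix[2][n] - prefix[2][j];
        return red + green + blue;
    }

    public static void main(String[] args) {
        // Illustrative labels, already sorted by predicted value.
        int[] labels = {0, 0, 1, 0, 1, 1, 2, 1, 2, 2};
        int[][] prefix = build(labels);
        System.out.println(correct(prefix, 2, 6)); // 8 of 10 correct
        System.out.println(correct(prefix, 3, 7)); // 6 of 10 correct
    }
}
```

Dividing the returned count by the number of items gives the accuracy for that pair of cut positions, so any search strategy over (i, j) can evaluate a candidate without rescanning the data.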

There can be a few corner cases. If you can afford O(N^2), just try all pairs of threshold values and use the precomputed arrays to speed up the evaluation.
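
A minimal sketch of that exhaustive search follows, under some assumptions: the values are already sorted ascending, labels are encoded as 0 = red, 1 = green, 2 = blue, and a value equal to a threshold counts as green, as in the question's `classify`. The class and method names are hypothetical. One corner case it handles explicitly is ties: a cut between two equal values cannot be expressed as a threshold, so those positions are skipped.

```java
class ExhaustiveThresholds {
    // values: sorted ascending; labels[k] is the class of values[k] (0 = red, 1 = green, 2 = blue).
    // Returns {low, high, numCorrect} for the rule:
    // value < low -> red, value > high -> blue, otherwise green.
    static double[] best(double[] values, int[] labels) {
        int n = values.length;
        // prefix[c][k] = number of items with label c among the first k items.
        int[][] prefix = new int[3][n + 1];
        for (int k = 0; k < n; k++)
            for (int c = 0; c < 3; c++)
                prefix[c][k + 1] = prefix[c][k] + (labels[k] == c ? 1 : 0);

        double bestLow = Double.NEGATIVE_INFINITY;
        double bestHigh = Double.POSITIVE_INFINITY;
        int bestCorrect = -1;
        for (int i = 0; i <= n; i++) {
            if (!isCut(values, i)) continue;
            for (int j = i; j <= n; j++) {
                if (!isCut(values, j)) continue;
                // [0, i) predicted red, [i, j) green, [j, n) blue.
                int correct = prefix[0][i]
                        + (prefix[1][j] - prefix[1][i])
                        + (prefix[2][n] - prefix[2][j]);
                if (correct > bestCorrect) {
                    bestCorrect = correct;
                    bestLow = midpoint(values, i);
                    bestHigh = midpoint(values, j);
                }
            }
        }
        return new double[]{bestLow, bestHigh, bestCorrect};
    }

    // A cut before index k is usable only if it does not split equal values.
    static boolean isCut(double[] v, int k) {
        return k == 0 || k == v.length || v[k - 1] < v[k];
    }

    // Threshold halfway between the two values around the cut.
    static double midpoint(double[] v, int k) {
        if (k == 0) return Double.NEGATIVE_INFINITY;
        if (k == v.length) return Double.POSITIVE_INFINITY;
        return (v[k - 1] + v[k]) / 2;
    }

    public static void main(String[] args) {
        // Illustrative data with a tie at 0.0 (one red, one green).
        double[] values = {-3.0, -2.0, -1.0, 0.0, 0.0, 1.0, 2.0, 3.0};
        int[] labels = {0, 0, 1, 0, 1, 1, 2, 2};
        double[] r = best(values, labels);
        System.out.println(r[0] + " " + r[1] + " " + (int) r[2]); // -1.5 1.5 7
    }
}
```

The search maximizes the same count that `getAccuracy` in the question divides by the number of instances, so the thresholds it returns are optimal for that metric on the given data, at quadratic cost.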

Sorin
  • The reason I thought Kadane's algorithm would work is that the ranges are somewhat separable: there is very little overlap between red and blue, so wouldn't optimizing the number of greens taken also optimize the others? I don't quite understand your second paragraph. Kadane's picks a single range, so even though the sum of the middle part is 0, the starting part and ending part are >0, so it'll include the entire range? – Limon Mar 03 '16 at 16:43
  • @Limon Yes, my point was that because the middle part is 0, the solution will include the entire range, which is likely not what you want. – Sorin Mar 04 '16 at 15:54