You could try using a NavigableMap keyed by the cumulative probability distribution. Unlike ordinary Maps, a NavigableMap defines a total ordering over its keys, and if a key isn't present in the map it can tell you which key is closest to it, or which is the smallest key that is greater than the argument. I've used ceilingEntry,
which returns the map entry with the smallest key that is greater than or equal to the given key.
If you use a TreeMap as your implementation of NavigableMap, lookups on distributions with many classes will be faster, because TreeMap searches a balanced binary tree (O(log n)) rather than starting with the first key and testing each key in turn.
Another advantage of NavigableMap is that the lookup returns the class you're directly interested in, rather than an index into another array or list, which can make the code cleaner.
In my example I've used BigDecimal because I'm not particularly fond of floating-point numbers: you can't specify the precision you need. You could just as well use floats, doubles, or whatever suits you.
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.NavigableMap;
import java.util.TreeMap;
/**
 * Samples a class label from a discrete probability distribution by looking up a uniform
 * random number in a {@link NavigableMap} keyed by cumulative probability.
 */
public class Main {

    /**
     * Builds the distribution for four classes, prints it, then draws and prints one sample.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String[] classes = {"A", "B", "C", "D"};
        BigDecimal[] probabilities = createProbabilities(classes.length);
        BigDecimal[] distribution = createDistribution(probabilities);
        System.out.println("probabilities: " + Arrays.toString(probabilities));
        System.out.println("distribution: " + Arrays.toString(distribution) + "\n");
        NavigableMap<BigDecimal, String> map = createLookup(classes, probabilities, distribution);
        // Math.random() is in [0, 1); new BigDecimal(double) keeps its exact binary value.
        BigDecimal d = new BigDecimal(Math.random());
        System.out.println("probability: " + d);
        System.out.println("result: " + sample(map, d));
    }

    /**
     * Builds a lookup from each class's cumulative upper bound to the class itself, so that
     * {@code ceilingKey(d)} selects the class whose interval (previous bound, own bound]
     * contains {@code d}.
     *
     * @param classes the class labels, in the same order as the probabilities
     * @param probabilities the per-class probabilities
     * @param distribution the running sums of {@code probabilities}
     * @return a sorted map containing only classes with positive probability
     * @throws IllegalArgumentException if the three arrays differ in length
     */
    static <T> NavigableMap<BigDecimal, T> createLookup(
            T[] classes, BigDecimal[] probabilities, BigDecimal[] distribution) {
        if (classes.length != probabilities.length || classes.length != distribution.length) {
            throw new IllegalArgumentException("array lengths differ: classes=" + classes.length
                    + ", probabilities=" + probabilities.length
                    + ", distribution=" + distribution.length);
        }
        NavigableMap<BigDecimal, T> map = new TreeMap<BigDecimal, T>();
        for (int i = 0; i < classes.length; i++) {
            // A zero-probability class has the same cumulative bound as its predecessor.
            // TreeMap compares keys with compareTo, so putting it would overwrite the
            // predecessor and hand the zero-probability class that whole interval.
            if (probabilities[i].signum() > 0) {
                map.put(distribution[i], classes[i]);
            }
        }
        return map;
    }

    /**
     * Returns the class whose cumulative interval contains {@code d}.
     *
     * @param map a lookup built by {@link #createLookup}
     * @param d the value to look up, normally drawn uniformly from [0, 1)
     * @return the selected class; values above the last bound map to the last class
     * @throws IllegalArgumentException if {@code map} is empty
     */
    static <T> T sample(NavigableMap<BigDecimal, T> map, BigDecimal d) {
        if (map.isEmpty()) {
            throw new IllegalArgumentException("no class has a positive probability");
        }
        BigDecimal key = map.ceilingKey(d);
        if (key == null) {
            // Each probability is rounded, so the cumulative total may fall marginally short
            // of 1; a draw above the last bound belongs to the last class, not to null.
            key = map.lastKey();
        }
        return map.get(key);
    }

    /**
     * Computes the running (cumulative) sums of the given probabilities.
     *
     * @param probabilities per-class probabilities; may be empty
     * @return an array whose element {@code i} is the sum of elements {@code 0..i}
     */
    static BigDecimal[] createDistribution(BigDecimal[] probabilities) {
        BigDecimal[] distribution = new BigDecimal[probabilities.length];
        BigDecimal running = BigDecimal.ZERO;
        for (int i = 0; i < probabilities.length; i++) {
            running = running.add(probabilities[i]);
            distribution[i] = running;
        }
        return distribution;
    }

    /**
     * Creates {@code n} probabilities where class {@code i} (1-based) gets
     * 6i(n-i) / (n^3 - n). These sum to 1, and the last class always gets 0.
     *
     * @param n the number of classes
     * @return the per-class probabilities, each rounded to 64 decimal places
     * @throws IllegalArgumentException if {@code n < 2} (the denominator is 0 for n = 0 or 1)
     */
    static BigDecimal[] createProbabilities(int n) {
        if (n < 2) {
            throw new IllegalArgumentException("need at least 2 classes, got " + n);
        }
        BigDecimal[] probabilities = new BigDecimal[n];
        for (int i = 0; i < n; i++) {
            probabilities[i] = probabilityOf(i + 1, n);
        }
        return probabilities;
    }

    /** Returns 6i(n-i) / (n^3 - n), rounded half-up to 64 decimal places; requires n >= 2. */
    private static BigDecimal probabilityOf(int i, int n) {
        BigDecimal j = BigDecimal.valueOf(i);
        BigDecimal m = BigDecimal.valueOf(n);
        BigDecimal dividend = m.subtract(j).multiply(j).multiply(BigDecimal.valueOf(6));
        BigDecimal divisor = m.pow(3).subtract(m);
        return dividend.divide(divisor, 64, RoundingMode.HALF_UP);
    }
}