I want to choose a random item from a set, but the chance of choosing any item should be proportional to its associated weight.

Example inputs:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.
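In other words, each item's chance is its weight divided by the sum of all weights. Worked out for the table above:

```latex
\begin{aligned}
P(i) &= \frac{w_i}{\sum_j w_j}, \qquad \sum_j w_j = 10 + 5 + 6 + 1 = 22 \\
P(\text{sword of misery}) &= \tfrac{10}{22} \approx 0.455 \\
P(\text{shield of happy}) &= \tfrac{5}{22} \approx 0.227 \\
P(\text{potion of dying}) &= \tfrac{6}{22} \approx 0.273 \\
P(\text{triple-edged sword}) &= \tfrac{1}{22} \approx 0.045 \\
\frac{P(\text{sword of misery})}{P(\text{triple-edged sword})} &= \frac{10/22}{1/22} = 10
\end{aligned}
```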

How do I make a weighted random selection in Java?

Peter Lawrey
yosi

6 Answers

I would use a NavigableMap keyed by the running total of the weights:

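import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    // Keys are running totals of the weights; each value owns the interval
    // [previous total, its total).
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        // Skip non-positive weights: a zero weight would reuse the previous
        // total as its key and overwrite that entry.
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        // Uniform value in [0, total); the first key strictly above it picks the item.
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}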

Say I have a list of animals, dog, cat and horse, with probabilities of 40%, 35% and 25% respectively:

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
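A way to see why `higherEntry` weights the choice correctly: the map stores the running totals 40, 75 and 100 as keys, so every number in [0, 100) has exactly one smallest key above it. The sketch below (class and method names are illustrative, not part of the answer) builds that map for the dog/cat/horse example and looks up fixed numbers in place of `random.nextDouble() * total`:

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Deterministic illustration of RandomCollection.next(): keys are running totals,
// so "dog" owns [0, 40), "cat" owns [40, 75) and "horse" owns [75, 100).
class CumulativeLookupDemo {
    static NavigableMap<Double, String> buildMap() {
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        total += 40; map.put(total, "dog");   // key 40.0
        total += 35; map.put(total, "cat");   // key 75.0
        total += 25; map.put(total, "horse"); // key 100.0
        return map;
    }

    // 'value' stands in for random.nextDouble() * total, a number in [0, 100).
    static String pick(NavigableMap<Double, String> map, double value) {
        // higherEntry returns the entry with the smallest key strictly greater than value.
        return map.higherEntry(value).getValue();
    }
}
```

Here `pick(map, 39.9)` gives "dog", `pick(map, 40.0)` gives "cat" and `pick(map, 99.9)` gives "horse". Because the scaled random number is uniform on [0, total), each item is chosen with probability (width of its interval) / total, which is weight / total.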
Peter Lawrey
  • @cpu_meltdown The cost of log(n) ;) – Peter Lawrey Feb 27 '19 at 20:25
  • Thanks for the answer Peter! It works well. If anyone, like me, was wondering about the `if (weight <= 0) return this;`, it serves an important purpose. Without it, if you reuse the example and call `.add(0, "lizard")` *after* the call to `.add(25, "horse")`, the new call to `put(total, result)` uses the *same* total as the previous entry, so it overwrites the entry for "horse" with "lizard", even though "lizard" should have had a 0% chance of being selected. – Nicolas Favre-Felix Jul 08 '20 at 19:27
  • @PeterLawrey Thanks, how would you update this if we wanted to support random removal in addition to random choice? – shmth Apr 02 '21 at 05:24
  • Does it need to add up to 100%? – CiY3 Jan 20 '22 at 15:19
  • @CiY3 No, you can use any weights you like; the variable `total` holds their sum, which corresponds to 100%. – Ninja Jul 16 '22 at 10:42
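The overwrite described in the comment about `.add(0, "lizard")` follows from `TreeMap.put` replacing the value stored under an existing key. A minimal sketch (hypothetical class name) that repeats the `add` logic *without* the `weight <= 0` guard shows it:

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Repeats RandomCollection.add() without the "weight <= 0" guard (illustration only).
class UnguardedAdd {
    static NavigableMap<Double, String> build() {
        double[] weights = {40, 35, 25, 0};
        String[] names = {"dog", "cat", "horse", "lizard"};
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        for (int i = 0; i < weights.length; i++) {
            total += weights[i];
            // With weight 0 the total stays at 100.0, so put() replaces "horse".
            map.put(total, names[i]);
        }
        return map;
    }
}
```

The resulting map has only three entries and key 100.0 now maps to "lizard", so "horse" can never be drawn while "lizard" gets 25%.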

There is now a class for this in Apache Commons Math: EnumeratedDistribution (org.apache.commons.math3.distribution.EnumeratedDistribution)

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

where itemWeights is a List<Pair<Item, Double>>, built for example like this (assuming the Item interface from Arne's answer):

final List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i: itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}

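or in Java 8 (with a static import of java.util.stream.Collectors.toList):

final List<Pair<Item, Double>> itemWeights =
    itemSet.stream().map(i -> new Pair<>(i, i.getWeight())).collect(toList());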

Note: Pair here needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.

Holt
kdkeck
  • This should really be higher in the list of answers... Why reinvent the wheel? Furthermore, `EnumeratedDistribution` allows the selection of multiple samples at once, which is pretty neat. – Holt Oct 24 '19 at 13:02
  • Commons Math3 is now unsupported. The functionality of `EnumeratedDistribution` has been moved to [DiscreteProbabilityCollectionSampler](https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.html) in the [Commons RNG](https://commons.apache.org/proper/commons-rng/) library. – user1585916 Jun 10 '21 at 20:03

You will not find a framework for this kind of problem, as the requested functionality is nothing more than a simple function. Do something like this:

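import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        // Sum all weights.
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        // Uniform value in [0, completeWeight).
        double r = Math.random() * completeWeight;
        // Walk the running total until it passes r. Using > rather than >=
        // means a zero-weight item can never be chosen, even when r == 0.
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight > r)
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}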
Steven Behnke
Arne Deutsch
  • which order should be used for the items list? from higher to smaller? Thanks. – aloplop85 Aug 14 '15 at 19:27
  • You do not need to sort the list. Order does not matter. – Arne Deutsch Aug 16 '15 at 11:10
  • The order of the items in the list doesn't matter, because `r` is a uniformly distributed random number, which means that the probability of `r` taking a certain value is the same for every value `r` may take. Thus no item in the list is "favored", and it doesn't matter where it sits in the list. – matthaeus Feb 05 '16 at 14:49
  • If you use `countWeight >= r`, an item with a weight of zero can be selected if it happens to be the first item and r = 0, which can happen. – Andrei Volgin Dec 30 '18 at 01:09
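The zero-weight edge case from the last comment is easy to reproduce once `r` is passed in rather than drawn with `Math.random()` (which can return exactly 0.0). A sketch with hypothetical names that works on plain weights and lets you pick either comparison:

```java
import java.util.List;

// Same linear scan as chooseOnWeight(), but r (in [0, total weight)) is supplied
// by the caller and the method returns the index of the chosen weight.
class LinearScan {
    // strict == true compares with countWeight > r, strict == false with countWeight >= r.
    static int choose(List<Double> weights, double r, boolean strict) {
        double countWeight = 0.0;
        for (int i = 0; i < weights.size(); i++) {
            countWeight += weights.get(i);
            if (strict ? countWeight > r : countWeight >= r) {
                return i;
            }
        }
        throw new IllegalStateException("r must lie in [0, total weight)");
    }
}
```

With weights [0, 1, 2] and r = 0.0, the `>=` version returns index 0 (the zero-weight item) while the `>` version returns index 1. Reordering the list only moves the sub-interval of [0, total) that each item owns, not its width, which is why order does not affect the probabilities.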

Use an alias method

If you're going to roll a lot of times (as in a game), you should use an alias method.

The code below is a rather long implementation of such an alias method, but that is because of the initialization part. Retrieving elements is very fast: the next and applyAsInt methods don't loop.
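Why retrieval needs no loop: the initialization turns the n weights into n equally likely columns, and each column holds at most two outcomes, its own item (kept with probability `probabilities[column]`) and one `alias`. For the question's weights in table order (sword 10, shield 5, potion 6, triple-edged 1), one valid table, worked out by hand with Vose's pairing of below-average and above-average columns, is probabilities = {1, 20/22, 6/22, 4/22} and alias = {unused, sword, sword, potion}. The sketch below (hypothetical class name) computes exactly which distribution such a table produces when it is read the way `applyAsInt` reads it:

```java
// Exact (non-random) reading of an alias table: each of the n columns is chosen
// with probability 1/n; the column's own item is kept with probability prob[column],
// otherwise alias[column] is returned.
class AliasTableCheck {
    static double[] impliedDistribution(double[] prob, int[] alias) {
        int n = prob.length;
        double[] result = new double[n];
        for (int column = 0; column < n; column++) {
            result[column] += prob[column] / n;                // own item kept
            result[alias[column]] += (1.0 - prob[column]) / n; // alias returned
        }
        return result;
    }
}
```

For the table above this returns 10/22, 5/22, 6/22 and 1/22 (up to rounding), matching the weights. Column 0 has probability 1, so whatever index its alias slot holds is never returned. Drawing from the table costs one `nextInt` and one `nextDouble` regardless of n.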

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (well, at least, I tried to do so using micro-benchmarking);
  • is totally thread-safe (keep one Random in each thread for maximum performance, use ThreadLocalRandom?);
  • fetches elements in O(1), unlike what you mostly find on the internet or on StackOverflow, where naive implementations run in O(n) or O(log(n));
  • keeps the items independent from their weights, so an item can be assigned different weights in different contexts.

Anyway, here's the code. (Note that I maintain an up-to-date version of this class.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

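    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;
      // Keep the table; the array is rewritten in place into acceptance thresholds.
      this.probabilities = probabilities;
      this.alias = new int[size];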

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

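      // For each small column, fill it up to average with part of a large column.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        alias[less] = more;
        // The large column gives away (average - probabilities[less]); this must use
        // the small column's unscaled probability, so update it before rescaling.
        probabilities[more] += probabilities[less] - average;
        // Scale the small column's own share to a threshold in [0, 1).
        probabilities[less] = probabilities[less] * size;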
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
Olivier Grégoire
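import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    // Ignore non-positive weights and results that were already added.
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    // Uniform value in [0, total); the first key >= value picks the item.
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}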
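Compared with the first answer, this version uses `ceilingEntry` instead of `higherEntry` and silently skips a result that was already added (`containsValue` scans the whole map, so `add` becomes O(n)). The two lookups differ only when the random value lands exactly on a running total; every item's interval still has a width equal to its weight, so the selection probabilities are the same. A small check using the dog/cat/horse totals (class name hypothetical):

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Running totals for dog (40), cat (35) and horse (25), as add() would build them.
class BoundaryDemo {
    static NavigableMap<Double, String> cumulative() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(40.0, "dog");
        map.put(75.0, "cat");
        map.put(100.0, "horse");
        return map;
    }

    // ceilingEntry: smallest key >= value, so an exact total maps to the item ending there.
    static String viaCeiling(double value) {
        return cumulative().ceilingEntry(value).getValue();
    }

    // higherEntry: smallest key > value, so an exact total maps to the next item.
    static String viaHigher(double value) {
        return cumulative().higherEntry(value).getValue();
    }
}
```

Only a value of exactly 40.0 or 75.0 is resolved differently (for example `viaCeiling(40.0)` is "dog" while `viaHigher(40.0)` is "cat"); all other values give the same item.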
ronen

A simple (perhaps naive) but, I believe, straightforward method:

/**
 * Draws an integer from the range [min, max), i.e. excluding the upper limit.
 * <p>
 * Similar to Python's random.randrange(min, max).
 *
 * @param min the smallest value that can be drawn.
 * @param max the upper limit (exclusive).
 * @return The value drawn.
 */
public static int randomInt(int min, int max)
    {return (int) (min + Math.random() * (max - min));}

/**
 * Tests whether all inner vectors of a given matrix
 * have the expected length.
 * @param matrix the matrix whose vectors' lengths will be measured.
 * @param expectedLength the length each vector should have.
 * @return false if at least one vector has a different length.
 */
public static boolean haveAllVectorsEqualLength(int[][] matrix, int expectedLength){
    for(int[] vector: matrix){if (vector.length != expectedLength) {return false;}}
    return true;
}

/**
 * Draws an integer according to weighted values.
 *
 * @param ticketBlock matrix of {value, weight} pairs for the drawing. All its
 * vectors should have length two. Instead of percentages, the weights should be
 * non-negative integers reflecting how common each value should be: the rarest
 * value receives the smallest weight.
 * @return The value drawn.
 */
public static int weightedRandomInt(int[][] ticketBlock) throws RuntimeException {
    boolean theVectorsHaventAllLengthTwo = !(haveAllVectorsEqualLength(ticketBlock, 2));
    if (theVectorsHaventAllLengthTwo)
        {throw new RuntimeException("The given matrix has at least one vector whose length is not two.");}
    // Need to test for duplicates or null values in ticketBlock!

    // Raffle urn building: one slot per unit of weight.
    int raffleUrnSize = 0, urnIndex = 0;
    for(int[] ticket: ticketBlock){
        if (ticket[1] < 0) {throw new RuntimeException("Weights must not be negative.");}
        raffleUrnSize += ticket[1];
    }
    if (raffleUrnSize == 0)
        {throw new RuntimeException("The weights must sum to a positive number.");}
    int[] raffleUrn = new int[raffleUrnSize];

    // Raffle urn filling: repeat each value as many times as its weight
    // (a zero weight adds no slots).
    for(int[] ticket: ticketBlock){
        for(int repetition = 0; repetition < ticket[1]; repetition++){
            raffleUrn[urnIndex++] = ticket[0];
        }
    }

    return raffleUrn[randomInt(0, raffleUrn.length)];
}
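The urn approach trades memory for simplicity: the urn holds one slot per unit of weight, so it suits small integer weights. A standalone sketch of just the urn-building step (hypothetical class name), with the question's items numbered 1 to 4:

```java
// One slot per unit of weight: a value with weight w occupies w slots, so a uniform
// index into the urn returns it with probability w / (sum of weights).
class RaffleUrn {
    static int[] build(int[][] ticketBlock) {
        int size = 0;
        for (int[] ticket : ticketBlock) {
            if (ticket[1] < 0) {
                throw new IllegalArgumentException("weights must not be negative");
            }
            size += ticket[1];
        }
        int[] urn = new int[size];
        int index = 0;
        for (int[] ticket : ticketBlock) {
            for (int k = 0; k < ticket[1]; k++) { // zero weight: no slots
                urn[index++] = ticket[0];
            }
        }
        return urn;
    }
}
```

For the question's weights, `RaffleUrn.build(new int[][] {{1, 10}, {2, 5}, {3, 6}, {4, 1}})` yields a 22-slot urn, and a draw is then `urn[ThreadLocalRandom.current().nextInt(urn.length)]`.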