
Question Summary

Is there any way of updating the probabilities within an existing instance of the class EnumeratedIntegerDistribution without creating an entirely new instance?

Background

I'm trying to implement a simplified Q-learning style demonstration on an Android phone. I need to update the probabilities for each item on each pass through the learning loop. Currently I am unable to find any method accessible from my instance of EnumeratedIntegerDistribution that will let me reset, update, or modify these probabilities. Therefore, the only way I can see to do this is to create a new instance of EnumeratedIntegerDistribution within each loop. Keeping in mind that each of these loops is only 20 ms long, it is my understanding that this would be terribly memory-inefficient compared to creating one instance and updating the values within it. Are there no standard set-style methods to update these probabilities? If not, is there a recommended workaround (i.e. using a different class, making my own class, overriding something to make it accessible, etc.)?
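To illustrate the "making my own class" option, here is a minimal sketch of the kind of mutable sampler I have in mind (class and method names are my own invention, plain `java.util.Random`, no Commons Math):

```java
import java.util.Random;

// Hypothetical minimal mutable sampler: weights can be changed in place
// between samples, so no new object is allocated per loop iteration.
public class MutableWeightedSampler {

    private final double[] weights;
    private final Random random = new Random();

    public MutableWeightedSampler(double[] weights) {
        this.weights = weights.clone();
    }

    // Overwrite a single weight in place.
    public void setWeight(int index, double weight) {
        weights[index] = weight;
    }

    // Sample an index with probability proportional to its weight.
    public int sample() {
        double sum = 0;
        for (double w : weights) sum += w;
        double r = random.nextDouble() * sum;
        double acc = 0;
        for (int i = 0; i < weights.length; i++) {
            acc += weights[i];
            if (r < acc) return i;
        }
        return weights.length - 1; // guard against floating-point rounding
    }
}
```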

A follow-up question is whether this effort is moot. Would the compiled code actually be any more or less efficient if I avoided creating a new instance every loop? (I'm not knowledgeable enough to know how compilers handle such things.)

Code

A minimal example below:

package com.example.mypackage.learning;  
  
import android.app.Activity;  
import android.os.Bundle;  
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;  
  
  
public class Qlearning extends Activity {  
  
    private int selectedAction;  
    private int[] actions = {0, 1, 2};  
    private double[] weights = {1.0, 1.0, 1.0};  
    private double[] qValues = {1.0, 1.0, 1.0};  
    private double qValuesSum;  
    EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);  
    private final double alpha = 0.001;  
    int action;  
    double reward;  
  
    @Override  
    protected void onCreate(Bundle savedInstanceState) {  
        super.onCreate(savedInstanceState);  
        while(true){  
            action = determineAction();  
            reward = determineReward();  
            learn(action, reward);  
        }  
    }  
      
    public void learn(int action, double reward) {  
        qValues[action] = (alpha * reward) + ((1.0 - alpha) * qValues[action]);  
        qValuesSum = 0;  
        for (int i = 0; i < qValues.length; i++){  
            qValuesSum += Math.exp(qValues[i]);  
        }  
        weights[action] = Math.exp(qValues[action]) / qValuesSum;  
        // *** This seems inefficient: a new instance every 20 ms loop ***  
        enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);  
    }  
}

Please don't focus on the absence of the determineAction() or determineReward() methods, as this is simply a minimal example. You could easily just sub in fixed values there (e.g. 1, and 1.5) if you wanted a working example.

Also, I'm well aware of the infinite while loop that would be troublesome for a GUI, but again, just trying to reduce the code I have to show here to get the point across.

Edit:

In response to a comment I'm posting what I had for a similar class below. Note I haven't used this in over a year and things may be broken. Just posting for reference:

import android.util.Log;

public class ActionDistribution{

    private double reward = 0;
    private double[] weights = {0.34, 0.34, 0.34};
    private double[] qValues = {0.1, 0.1, 0.1};
    private double learningRate = 0.1;
    private double temperature = 1.0;
    private int selectedAction;

    public ActionDistribution(){}

    public ActionDistribution(double[] weights, double[] qValues, double learningRate, double temperature){
        this.weights = weights;
        this.qValues = qValues;
        this.learningRate = learningRate;
        this.temperature = temperature;
    }

    public int actionSelect(){

        double sumOfWeights = 0;
        for (double weight: weights){
            sumOfWeights = sumOfWeights + weight;
        }
        double randNum = Math.random() * sumOfWeights;
        double selector = 0;
        int iterator = -1;

        while (selector < randNum){
            try {
                iterator++;
                selector = selector + weights[iterator];
            }catch (ArrayIndexOutOfBoundsException e){
                Log.e("abcvlib", "weight index bound exceeded. randNum was greater than the sum of all weights. This can happen if the sum of all weights is less than 1.");
                iterator = weights.length - 1; // fall back to the last action instead of looping forever
                break;
            }
        }
        // Assigning this as a read-only value to pass between threads.
        this.selectedAction = iterator;

        // represents the action to be selected
        return iterator;
    }

    public double[] getWeights(){
        return weights;
    }

    public double[] getqValues(){
        return qValues;
    }

    public double getQValue(int action){
        return qValues[action];
    }

    public double getTemperature(){
        return temperature;
    }


    public int getSelectedAction() {
        return selectedAction;
    }

    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    public void setQValue(int action, double qValue) {
        this.qValues[action] = qValue;
    }

    public void updateValues(double reward, int action){

        double qValuePrev = getQValue(action);

        // update qValues due to current reward
        setQValue(action,(learningRate * reward) + ((1.0 - learningRate) * qValuePrev));
        // update weights from new qValues

        double qValuesSum = 0;
        for (double qValue : getqValues()) {
            qValuesSum += Math.exp(temperature * qValue);
        }

        // update weights
        for (int i = 0; i < getWeights().length; i++){
            getWeights()[i] = Math.exp(temperature * getqValues()[i]) / qValuesSum;
        }
    }

    public double getReward() {
        return reward;
    }

    public void setReward(double reward) {
        this.reward = reward;
    }
}
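For reference, the weight update in updateValues above amounts to a softmax over the temperature-scaled Q-values. A standalone sketch of just that step (class and method names are mine, not part of the class above):

```java
// Sketch of the weight update used in ActionDistribution.updateValues:
// a softmax over temperature-scaled Q-values.
public class SoftmaxSketch {

    // Returns positive weights that sum to 1.
    static double[] softmax(double[] qValues, double temperature) {
        double sum = 0;
        for (double q : qValues) sum += Math.exp(temperature * q);
        double[] weights = new double[qValues.length];
        for (int i = 0; i < qValues.length; i++) {
            weights[i] = Math.exp(temperature * qValues[i]) / sum;
        }
        return weights;
    }
}
```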
topher217
  • No, I don't think so – Severin Pappadeux Nov 12 '19 at 01:39
    I ended up making my own class to work around this, but I was hoping to find something a bit more established (constructors, error handling, type restrictions, etc.) I'll post my basic class here later if no better answers come along. Maybe Java just doesn't have any good libraries for Q-learning or reinforcement learning? – topher217 Nov 12 '19 at 03:17
  • @topher217 would you still have the implementation you wrote? – Ishaan Dec 01 '20 at 01:18
  • @Ishaan, I just updated my question to include what I was able to find a year later :D ... no promises, but hopefully it can give you some hints towards something that works for you. – topher217 Dec 01 '20 at 15:00
  • @topher217 amazing, thank you! – Ishaan Dec 01 '20 at 17:24

1 Answer


Unfortunately it is not possible to update an existing EnumeratedIntegerDistribution. I had a similar issue in the past and ended up re-creating the instance every time I needed to update the chances.

I wouldn't worry too much about the memory allocations, as those will be short-lived objects. These are micro-optimisations you should not worry about.
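To get a feel for the cost, here is a rough sketch (not a rigorous benchmark; the helper class is hypothetical) that allocates a fresh, normalized copy of a small weights array a million times, which is the same per-loop allocation pattern as re-creating a small distribution object:

```java
// Hypothetical sketch: allocate a fresh normalized weights array per pass,
// mimicking the per-loop allocation of re-creating a small distribution.
public class AllocationSketch {

    // Returns a newly allocated, normalized copy of the weights.
    static double[] copyNormalized(double[] weights) {
        double sum = 0;
        for (double w : weights) sum += w;
        double[] out = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            out[i] = weights[i] / sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 2.0, 3.0};
        long start = System.nanoTime();
        double[] last = null;
        for (int i = 0; i < 1_000_000; i++) {
            last = copyNormalized(weights); // fresh short-lived allocation each pass
        }
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
        // Totals are typically a few tens of milliseconds, i.e. tens of
        // nanoseconds per allocation - negligible against a 20 ms loop.
        System.out.println("1e6 allocations took ~" + elapsedMs + " ms, last[0]=" + last[0]);
    }
}
```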

In my project I implemented a cleaner way, using interfaces, to create instances of the EnumeratedDistribution class.

This is not the direct answer but might guide you in the right direction.

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;

public class DistributedProbabilityGeneratorBuilder<T extends DistributedProbabilityGeneratorBuilder.ProbableItem> {

    private static final DistributedProbabilityGenerator EMPTY = () -> {
        throw new UnsupportedOperationException("Not supported");
    };

    private final Map<Integer, T> distribution = new HashMap<>();

    private DistributedProbabilityGeneratorBuilder() {
    }

    public static <T extends ProbableItem> DistributedProbabilityGeneratorBuilder<T> newBuilder() {
        return new DistributedProbabilityGeneratorBuilder<>();
    }

    public DistributedProbabilityGenerator build() {
        return build(ProbableItem::getChances);
    }

    /**
     * Returns a new instance of probability generator at every call.
     * @param chanceChangeFunction - Function to modify existing chances
     */
    public DistributedProbabilityGenerator build(Function<T, Double> chanceChangeFunction) {
        if (distribution.isEmpty()) {
            return EMPTY;
        } else {
            return new NonEmptyProbabilityGenerator(createPairList(chanceChangeFunction));
        }
    }

    private List<Pair<Integer, Double>> createPairList(Function<T, Double> chanceChangeFunction) {
        return distribution.entrySet().stream()
                .map(entry -> Pair.create(entry.getKey(), chanceChangeFunction.apply(entry.getValue())))
                .collect(Collectors.toList());
    }

    public DistributedProbabilityGeneratorBuilder<T> add(int id, T item) {
        if (distribution.containsKey(id)) {
            throw new IllegalArgumentException("Id " + id + " already present.");
        }

        this.distribution.put(id, item);
        return this;
    }

    public interface ProbableItem {

        double getChances();
    }

    public interface DistributedProbabilityGenerator {

        int generateId();
    }

    public static class NonEmptyProbabilityGenerator implements DistributedProbabilityGenerator {

        private final EnumeratedDistribution<Integer> enumeratedDistribution;

        NonEmptyProbabilityGenerator(List<Pair<Integer, Double>> pairs) {
            this.enumeratedDistribution = new EnumeratedDistribution<>(pairs);
        }

        @Override
        public int generateId() {
            return enumeratedDistribution.sample();
        }
    }

    public static ProbableItem ofDouble(double chances) {
        return () -> chances;
    }
}

Note - I am using EnumeratedDistribution<Integer>. You can easily change it to be EnumeratedIntegerDistribution.

The way I use the above class is as follows.

DistributedProbabilityGenerator distributedProbabilityGenerator = DistributedProbabilityGeneratorBuilder.newBuilder()
                .add(0, ofDouble(10))
                .add(1, ofDouble(45))
                .add(2, ofDouble(45))
                .build();

int generatedObjectId = distributedProbabilityGenerator.generateId();

Again, this is not a direct answer to your question but more of a pointer towards how you can use these classes in a better way.

Sneh