Question Summary
Is there any way of updating the probabilities within an existing instance of the class EnumeratedIntegerDistribution without creating an entirely new instance?
Background
I'm trying to implement a simplified Q-learning style demonstration on an Android phone. I need to update the probability of each item on every iteration of the learning process. Currently I can't find any method on my EnumeratedIntegerDistribution instance that lets me reset, update, or modify these probabilities. Therefore, the only way I can see to do this is to create a new instance of EnumeratedIntegerDistribution on each iteration. Since each iteration is only 20 ms long, my understanding is that this would be terribly memory-inefficient compared to creating one instance and updating the values within it. Are there no standard setter methods for these probabilities? If not, is there a recommended workaround (e.g., using a different class, writing my own class, or overriding something to make the probabilities accessible)?
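As far as I can tell from the Commons Math 3 API, EnumeratedIntegerDistribution exposes no setters for its probabilities, so writing your own class is the practical workaround. Below is a minimal sketch of such a class, assuming plain roulette-wheel sampling with java.util.Random is all you need. The name MutableWeightedSampler and its methods are hypothetical, not part of any library. The weights live in one array that is overwritten in place, and they need not sum to 1 because sample() scales its random number by the current total.

```java
import java.util.Random;

/**
 * Hypothetical helper (not part of Commons Math): a weighted sampler whose
 * weights can be overwritten in place, so no new object is needed per update.
 */
final class MutableWeightedSampler {
    private final int[] values;
    private final double[] weights;
    private final Random random;

    MutableWeightedSampler(int[] values, double[] weights, Random random) {
        if (values.length != weights.length || values.length == 0) {
            throw new IllegalArgumentException("values and weights must be non-empty and the same length");
        }
        this.values = values.clone();
        this.weights = new double[weights.length];
        this.random = random;
        setWeights(weights); // validates and copies into the internal array
    }

    /** Overwrites one weight in place; weights are relative and need not sum to 1. */
    public void setWeight(int index, double weight) {
        if (weight < 0 || Double.isNaN(weight) || Double.isInfinite(weight)) {
            throw new IllegalArgumentException("weight must be finite and non-negative");
        }
        weights[index] = weight;
    }

    /** Copies new weights into the existing array without reallocating it. */
    public void setWeights(double[] newWeights) {
        if (newWeights.length != weights.length) {
            throw new IllegalArgumentException("expected " + weights.length + " weights");
        }
        for (int i = 0; i < newWeights.length; i++) {
            setWeight(i, newWeights[i]);
        }
    }

    /** Roulette-wheel selection: one O(n) scan over the current weights. */
    public int sample() {
        double total = 0;
        for (double w : weights) {
            total += w;
        }
        if (total <= 0) {
            throw new IllegalStateException("at least one weight must be positive");
        }
        double r = random.nextDouble() * total; // r lies in [0, total)
        double cumulative = 0;
        for (int i = 0; i < weights.length; i++) {
            cumulative += weights[i];
            if (r < cumulative) {
                return values[i];
            }
        }
        // Floating-point round-off: fall back to the last value with positive weight.
        for (int i = weights.length - 1; i >= 0; i--) {
            if (weights[i] > 0) {
                return values[i];
            }
        }
        throw new AssertionError("unreachable: total > 0 implies a positive weight");
    }
}
```

Because the weights are read at sampling time, a learning loop can call setWeights(...) every 20 ms without allocating anything. If you stay with EnumeratedIntegerDistribution instead, its constructor overload that also takes a RandomGenerator lets you pass one shared generator to every rebuild, which, as I understand it, at least avoids creating and seeding a fresh generator each time.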
A follow-up question is whether this effort is moot. Would the compiled code actually be any more or less efficient if I avoided creating a new instance on every iteration? (I don't know enough about compilers to know how they handle this.)
Code
A minimal example is below:
```java
package com.example.mypackage.learning;

import android.app.Activity;
import android.os.Bundle;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;

public class Qlearning extends Activity {
    private int[] actions = {0, 1, 2};
    private double[] weights = {1.0, 1.0, 1.0};
    private double[] qValues = {1.0, 1.0, 1.0};
    private double qValuesSum;
    EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
    private final double alpha = 0.001; // learning rate
    int action;
    double reward;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        while (true) {
            action = determineAction();
            reward = determineReward();
            learn(action, reward);
        }
    }

    public void learn(int action, double reward) {
        // Update the Q-value of the action that was actually taken.
        qValues[action] = (alpha * reward) + ((1.0 - alpha) * qValues[action]);
        qValuesSum = 0;
        for (int i = 0; i < qValues.length; i++) {
            qValuesSum += Math.exp(qValues[i]);
        }
        // The normalizer changed, so every weight (not only the selected one) must be recomputed.
        for (int i = 0; i < weights.length; i++) {
            weights[i] = Math.exp(qValues[i]) / qValuesSum;
        }
        // *** This seems inefficient ***
        // Assign to the field; declaring a local variable here would leave the field unchanged.
        enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(actions, weights);
    }
}
```
Please don't focus on the absence of the determineAction() or determineReward() methods; this is simply a minimal example. If you want a working example, you could easily substitute fixed values (e.g., 1 and 1.5).
Also, I'm well aware that the infinite while loop would be troublesome for a GUI; again, I'm just trying to keep the code short enough to get the point across.
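For anyone who wants something runnable off-device, here is a plain-Java sketch (no Android, no Commons Math) of the learn() step with determineAction() and determineReward() stubbed to the fixed values 1 and 1.5, and the while(true) replaced by a bounded loop. The class name QLearningHarness and its run() method are hypothetical. The sketch assumes the intended update is: move the Q-value of the action actually taken toward the reward, then recompute all weights as a softmax of the Q-values.

```java
/**
 * Hypothetical plain-Java harness for the learn() step, with the action and
 * reward stubbed to fixed values so it runs without Android.
 */
final class QLearningHarness {
    private final double alpha = 0.001; // learning rate, as in the question
    final double[] qValues = {1.0, 1.0, 1.0};
    final double[] weights = {1.0, 1.0, 1.0};

    int determineAction() {
        return 1; // stub: always pick action 1
    }

    double determineReward() {
        return 1.5; // stub: constant reward
    }

    void learn(int action, double reward) {
        // Exponential moving average of the reward for the chosen action.
        qValues[action] = alpha * reward + (1.0 - alpha) * qValues[action];
        // Softmax: recompute every weight, since one new Q-value changes the normalizer.
        double sum = 0;
        for (double q : qValues) {
            sum += Math.exp(q);
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] = Math.exp(qValues[i]) / sum;
        }
    }

    void run(int steps) {
        // Bounded loop instead of while(true), so the harness terminates.
        for (int step = 0; step < steps; step++) {
            learn(determineAction(), determineReward());
        }
    }
}
```

Calling new QLearningHarness().run(5000) moves qValues[1] from 1.0 to just under 1.5 (convergence is slow with alpha = 0.001), while the other Q-values stay at 1.0 and weights[1] becomes the largest weight.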
Edit:
In response to a comment, I'm posting a similar class I wrote earlier. Note that I haven't used it in over a year, so things may be broken. It's posted for reference only:
```java
public class ActionDistribution {
    private double reward = 0;
    private double[] weights = {0.34, 0.34, 0.34};
    private double[] qValues = {0.1, 0.1, 0.1};
    private double learningRate = 0.1;
    private double temperature = 1.0;
    private int selectedAction;

    public ActionDistribution() {}

    public ActionDistribution(double[] weights, double[] qValues, double learningRate, double temperature) {
        this.weights = weights;
        this.qValues = qValues;
        this.learningRate = learningRate;
        this.temperature = temperature;
    }

    public int actionSelect() {
        double sumOfWeights = 0;
        for (double weight : weights) {
            sumOfWeights = sumOfWeights + weight;
        }
        // randNum lies in [0, sumOfWeights), so the weights need not sum to 1.
        double randNum = Math.random() * sumOfWeights;
        double selector = 0;
        // Fall back to the last action in case floating-point round-off leaves
        // randNum just above the final cumulative sum.
        int chosen = weights.length - 1;
        for (int i = 0; i < weights.length; i++) {
            selector = selector + weights[i];
            if (randNum < selector) {
                chosen = i;
                break;
            }
        }
        // Assigning this as a read-only value to pass between threads.
        this.selectedAction = chosen;
        // represents the action to be selected
        return chosen;
    }

    public double[] getWeights() {
        return weights;
    }

    public double[] getqValues() {
        return qValues;
    }

    public double getQValue(int action) {
        return qValues[action];
    }

    public double getTemperature() {
        return temperature;
    }

    public int getSelectedAction() {
        return selectedAction;
    }

    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    public void setQValue(int action, double qValue) {
        this.qValues[action] = qValue;
    }

    public void updateValues(double reward, int action) {
        double qValuePrev = getQValue(action);
        // update qValues due to current reward
        setQValue(action, (learningRate * reward) + ((1.0 - learningRate) * qValuePrev));
        // compute the softmax normalizer from the new qValues
        double qValuesSum = 0;
        for (double qValue : getqValues()) {
            qValuesSum += Math.exp(temperature * qValue);
        }
        // update weights
        for (int i = 0; i < getWeights().length; i++) {
            getWeights()[i] = Math.exp(temperature * getqValues()[i]) / qValuesSum;
        }
    }

    public double getReward() {
        return reward;
    }

    public void setReward(double reward) {
        this.reward = reward;
    }
}
```
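In updateValues() the weights are computed as exp(temperature * q) / sum, so the temperature field multiplies the Q-values: a larger value makes selection greedier. That is the reverse of the common Boltzmann convention exp(q / T), where a larger T flattens the distribution, so the field effectively behaves as an inverse temperature. The sketch below (a hypothetical Softmax helper in plain Java) makes this explicit with a parameter named beta. It also subtracts the largest exponent before calling Math.exp, which leaves the result unchanged but avoids overflow for large Q-values.

```java
/** Hypothetical helper: Boltzmann (softmax) weights with an inverse temperature beta. */
final class Softmax {
    private Softmax() {}

    /**
     * Returns w[i] = exp(beta * q[i]) / sum_j exp(beta * q[j]).
     * Shifting every exponent by the maximum cancels out in the ratio,
     * but keeps Math.exp from overflowing to Infinity.
     */
    static double[] weights(double[] q, double beta) {
        double[] w = new double[q.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < q.length; i++) {
            w[i] = beta * q[i];
            max = Math.max(max, w[i]);
        }
        double sum = 0;
        for (int i = 0; i < q.length; i++) {
            w[i] = Math.exp(w[i] - max);
            sum += w[i];
        }
        for (int i = 0; i < q.length; i++) {
            w[i] /= sum;
        }
        return w;
    }
}
```

With q = {0.1, 0.5, 0.2}, beta = 1 gives the best action roughly 41% of the probability mass, beta = 10 gives it roughly 94%, and beta = 0 gives a uniform distribution.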