I want to create a deep Q-network with Deeplearning4j, but I cannot figure out how to update the weights of my neural network using the calculated loss.
(I was mainly following this article.)
public class DDQN {
private static final double learningRate = 0.01;
private final MultiLayerNetwork qnet;
private final MultiLayerNetwork tnet;
private final ReplayMemory mem = new ReplayMemory(20000);
private final Batch batch = new Batch(1000);
public DDQN(int input, int hidden, int output) {
ListBuilder conf = new NeuralNetConfiguration.Builder().seed(Rnd.seed).weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(learningRate, 0.9)).list()
.layer(new DenseLayer.Builder().nIn(input).nOut(hidden).activation(Activation.IDENTITY).build())
.layer(new DenseLayer.Builder().nIn(input).nOut(hidden).activation(Activation.IDENTITY).build())
.layer(new DenseLayer.Builder().nIn(hidden).nOut(output).activation(Activation.IDENTITY)
.build());
qnet = new MultiLayerNetwork(conf.build());
qnet.init();
tnet = qnet.clone();
}
public INDArray tmpState = null;
public int tmpAction = -1;
public int getAction(double[] state) {
tmpState = Nd4j.createFromArray(new double[][] { state });
tmpAction = tnet.predict(tmpState)[0];
return tmpAction;
}
public void addResult(double reward, INDArray newState) {
mem.add(tmpState, tmpAction, reward, newState);
}
public void train(int size) {
mem.fillBatch(batch);
for (int i = 0; i < batch.size(); i++) {
// get q value of choosen action
INDArray out = qnet.output(batch.states[i]);
double q0 = out.getRow(0).getDouble(batch.actions[i]);
// get highest q value of the next state
out = tnet.output(batch.newStates[i]);
double q1 = out.maxNumber().doubleValue();
// calc mse
double err = q0 - (batch.rewards[i] + q1);
double mse = err * err;
// update neural net
// ??????
}
}
}
Replay Memory: (stores what the AI has experienced, for later training)
/**
 * Fixed-size circular buffer of (state, action, reward, nextState)
 * transitions. Once full, the oldest entries are overwritten.
 */
public class ReplayMemory {
private final INDArray[] states;
private final int[] actions;
private final double[] rewards;
private final INDArray[] newStates;
// next write index
private int pos = 0;
// true once the buffer has wrapped around at least once
private boolean filled = false;
/**
 * @param size maximum number of stored transitions
 */
public ReplayMemory(int size) {
states = new INDArray[size];
actions = new int[size];
rewards = new double[size];
newStates = new INDArray[size];
}
/**
 * Fills {@code b} with transitions sampled uniformly at random
 * (with replacement) from the occupied part of the buffer.
 *
 * @throws IllegalStateException if no transition has been added yet
 */
public void fillBatch(Batch b) {
// BUG FIX: was `pos + 1`, which allowed sampling index `pos` -- the NEXT
// write slot, still uninitialized (null state) before the buffer wraps.
final int max = filled ? states.length : pos;
if (max == 0) {
throw new IllegalStateException("replay memory is empty");
}
int r;
for (int i = 0; i < b.states.length; i++) {
r = Rnd.r.nextInt(max);
b.states[i] = states[r];
b.actions[i] = actions[r];
b.rewards[i] = rewards[r];
b.newStates[i] = newStates[r];
}
}
/** Appends one transition, overwriting the oldest entry once full. */
public void add(INDArray state, int action, double reward, INDArray newState) {
this.states[pos] = state;
this.actions[pos] = action;
this.rewards[pos] = reward;
this.newStates[pos] = newState;
if (++pos == this.size()) {
pos = 0;
filled = true;
}
}
/** @return the buffer capacity (not the number of occupied slots) */
public int size() {
return states.length;
}
}
Batch: (temporarily stores the current batch of experiences during training)
/**
 * Plain holder for one training batch: four parallel arrays, where index i
 * describes a single (state, action, reward, nextState) transition.
 */
public class Batch {
public final INDArray[] states;
public final int[] actions;
public final double[] rewards;
public final INDArray[] newStates;
/**
 * Allocates a batch of the given capacity.
 *
 * @param size number of transitions this batch holds
 */
public Batch(int size) {
this.states = new INDArray[size];
this.newStates = new INDArray[size];
this.actions = new int[size];
this.rewards = new double[size];
}
/** @return the number of transition slots in this batch */
public int size() {
return actions.length;
}
}
I have already tried searching Google and reading the documentation, with no luck.