I am reading through the DQN implementation in keras-rl (`rl/agents/dqn.py`) and see that in the `compile()` step essentially three Keras models are instantiated:
- `self.model`: provides the Q-value predictions
- `self.trainable_model`: same as `self.model`, but with the loss function we want to train on
- `self.target_model`: target model which provides the Q targets and is updated every `k` steps with the weights from `self.model` (roughly as sketched below)
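As far as I can tell, that target update is just a plain weight copy every `k` steps. Here is a simplified sketch of how I understand it (toy models built with `tf.keras`, not the actual keras-rl code):

```python
# Simplified sketch of the hard target update as I understand it
# (not the actual keras-rl code): every k steps the target network
# copies the online network's weights.
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

model = Sequential([Dense(2, input_shape=(4,))])         # plays the role of self.model
target_model = Sequential([Dense(2, input_shape=(4,))])  # plays the role of self.target_model

k = 100
for step in range(1000):
    # ... training of `model` would happen here ...
    if step % k == 0:
        target_model.set_weights(model.get_weights())
```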
The only model on which `train_on_batch()` is called is `trainable_model`. However, and this is what I don't understand, this call also updates the weights of `model`.
In the definition of `trainable_model`, one of the output tensors, `y_pred`, references the output of `model`:
```python
y_pred = self.model.output  # reuses the output tensor of self.model
y_true = Input(name='y_true', shape=(self.nb_actions,))
mask = Input(name='mask', shape=(self.nb_actions,))
loss_out = Lambda(clipped_masked_error, output_shape=(1,), name='loss')([y_true, y_pred, mask])
ins = [self.model.input] if type(self.model.input) is not list else self.model.input
trainable_model = Model(inputs=ins + [y_true, mask], outputs=[loss_out, y_pred])
```
When `trainable_model.train_on_batch()` is called, the weights of BOTH `trainable_model` and `model` change. This surprises me: even though the two models reference the same output tensor object (`trainable_model`'s `y_pred` is `model.output`), I would have expected the instantiation of `trainable_model = Model(...)` to also instantiate a new set of weights. Doesn't it?
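To make the behaviour I mean concrete, here is a toy reproduction outside of keras-rl (a minimal sketch using `tf.keras`; the names `base` and `wrapper` are mine, and the wrapper simply reuses the base model's output tensor, without the loss Lambda):

```python
# Toy reproduction of the behaviour: training the wrapper changes the base model's weights.
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

inp = Input(shape=(4,))
out = Dense(2)(inp)
base = Model(inputs=inp, outputs=out)      # analogous to self.model

# Wrap the same output tensor in a second Model, like trainable_model does.
wrapper = Model(inputs=base.input, outputs=base.output)
wrapper.compile(optimizer='sgd', loss='mse')

w_before = base.get_weights()[0].copy()
wrapper.train_on_batch(np.random.rand(8, 4), np.random.rand(8, 2))
w_after = base.get_weights()[0]

print(np.allclose(w_before, w_after))      # prints False
```

The Dense kernel of `base` changes even though `train_on_batch()` was only ever called on `wrapper`, which is exactly what I observe with `model` and `trainable_model` in the DQN agent.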
Thanks for the help!