First, I have to say that I know Dropout is not common in reinforcement learning (RL). You can read more about the topic, and why it may still make sense, here:
https://towardsdatascience.com/generalization-in-deep-reinforcement-learning-a14a240b155b
I am not sure how to implement Dropout in a Keras DQN. Usually (in supervised learning) Keras takes care of turning the Dropout layers on/off, depending on whether you are training or testing. In my case (trading with RL) I train on TRAIN data and evaluate on holdout TEST data, and they are NOT equal. The model overfits the TRAIN data and does not generalize well. I can see that it overfits just by viewing the train results: it memorizes and perfectly trades the trained data. That's why I want to use Dropout.
EDIT: Since "krenerd" suggested yet another way to do this, I'll summarize all 4 ways known to me here:
WAY 1: Using K.set_learning_phase(), i.e. K.set_learning_phase(1) to enable Dropout and K.set_learning_phase(0) to disable it.
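A minimal sketch of how I understand WAY 1 (this relies on the global backend learning phase; as far as I know it only works with the TF1-style Keras backend, and in some versions it has to be set before the model is built — model and state are the Keras model and a single input array as in the DQN example further down):

from tensorflow.keras import backend as K

# learning phase 1 = "training": Dropout layers are active during predict()
K.set_learning_phase(1)
preds_with_dropout = model.predict(state)

# learning phase 0 = "testing": Dropout layers are inactive
K.set_learning_phase(0)
preds_plain = model.predict(state)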
WAY 2: (https://stackoverflow.com/a/57439143/11122466) Using a K (backend) function:

func = K.function(model.inputs + [K.learning_phase()], model.outputs)

# run the model with dropout layers being active, i.e. learning_phase = 1
preds = func(list_of_input_arrays + [1])

# run the model with dropout layers being inactive, i.e. learning_phase = 0
preds = func(list_of_input_arrays + [0])
WAY 3: (https://stackoverflow.com/a/57439143/11122466) "Another approach is to define a new model with the same architecture but without setting training=True, and then transfer the weights from the trained model to this new model." This is very slow for me, around 1.5 ms per copy. Because it is slow, I don't like that solution.
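For reference, my understanding of WAY 3 as a sketch (which of the two models hard-codes training=True on the Dropout layer depends on how the trained model was built; the architecture just mirrors my _build_model() below, and the set_weights() call is the ~1.5 ms copy I measured):

from tensorflow.keras.layers import Input, Dense, Dropout, ReLU
from tensorflow.keras.models import Model

# identical architecture, but here Dropout is forced on for every forward pass
inp = Input(shape=(state_size,))
x = Dense(24, activation='relu')(inp)
x = Dense(24, activation='linear')(x)
x = Dropout(0.05)(x, training=True)   # stays active even in predict()
x = ReLU()(x)
out = Dense(action_size, activation='linear')(x)
dropout_model = Model(inp, out)

# transfer the weights from the trained model into the copy
dropout_model.set_weights(model.get_weights())
preds = dropout_model.predict(state)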
WAY 4 (suggested by "krenerd"): "Call the model with model(x, training=True)". Here I get a ValueError: Layer INPUT was called with an input that isn't a symbolic tensor. Received type: <class 'numpy.ndarray'>. My input is a numpy array, so I have to cast it to a tensor first.
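What I tried to get around the ValueError (a sketch assuming TF 2.x eager execution; tf.convert_to_tensor is my guess for the cast, and .numpy() turns the returned tensor back into the array format predict() gives):

import tensorflow as tf

state_tensor = tf.convert_to_tensor(state, dtype=tf.float32)

# Dropout active for this forward pass only
q_values_dropout = model(state_tensor, training=True).numpy()

# Dropout inactive
q_values = model(state_tensor, training=False).numpy()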
If you look at the following simple example for a DQN, modified/taken from:
https://github.com/keon/deep-q-learning/blob/master/dqn.py
import random
from collections import deque

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, ReLU
from tensorflow.keras.optimizers import Adam


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)   # replay buffer
        self.gamma = 0.95                  # discount rate
        self.epsilon = 1.0                 # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='linear'))
        model.add(Dropout(0.05))
        model.add(ReLU())
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        return model

    def act(self, state):
        # epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])  # returns action

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = (reward + self.gamma *
                          np.amax(self.model.predict(next_state)[0]))
            target_f = self.model.predict(state)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
ON TRAIN DATA: We call predict() twice in the replay() function, once for "state" and once for "next_state" (to create our target/label), and we call predict() once in act(). Do we enable Dropout for both predict() calls in replay()? Do we enable Dropout for the predict() call in act()?
ON TEST DATA: No exploration, epsilon = 0. Only act() is used to evaluate the performance on unseen data. The data from TEST is NOT saved in the replay buffer. I think we do not use Dropout here?
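For concreteness, this is roughly how act() would look if I went with WAY 4 and exposed the dropout behaviour as an argument (use_dropout is a name I made up, and whether/where to pass True is exactly what I'm asking; assumes import tensorflow as tf):

def act(self, state, use_dropout=False):
    # epsilon-greedy exploration (epsilon is 0 on TEST data)
    if np.random.rand() <= self.epsilon:
        return random.randrange(self.action_size)
    state_tensor = tf.convert_to_tensor(state, dtype=tf.float32)
    # training=use_dropout toggles the Dropout layer for this single forward pass
    act_values = self.model(state_tensor, training=use_dropout).numpy()
    return np.argmax(act_values[0])  # returns action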
2 Questions:
How/where would you insert the Dropout layer? Before or after the first Dense layer and/or before/after an Activation layer?
Which predict() calls have to be made with training = True / learning_phase = 1 and which not?