I have already checked the suggested duplicate question and confirmed this is not a duplicate issue.

Problem:

I have implemented an agent that uses a DQN with TensorFlow to learn the optimal policy for the game 'dots and boxes'. The algorithm appears to be working, judging by a rolling-average win rate against a random player, but eventually the Q-values output by the DQN grow until they overflow to inf, at which point an error is raised and the Q-function is no longer usable.

My reward structure is very simple: the agent gets -1 for a loss and 1 for a win. I have also clipped all the gradients to be between -1 and 1 in an attempt to stave off this behavior. Reducing the learning rate delays the Q-value explosion, but given enough time it happens regardless.

I've included relevant code below:

Gradient Clipping:

    # Clip gradients to prevent gradient explosion
    gradients = self.optimizer.compute_gradients(self.loss)
    clipped_gradients = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gradients]
    self.update_model = self.optimizer.apply_gradients(clipped_gradients)

(The optimizer is the RMSProp Optimizer)
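
For what it's worth, an alternative clipping scheme I have seen (not what I'm running above) is clipping by the global norm of all gradients rather than clipping each element independently. A rough sketch, where the clip norm of 5.0 is an arbitrary value I picked for illustration:

    # Sketch only: clip by the global norm of all gradients rather than element-wise.
    # The clip norm of 5.0 is an arbitrary value for illustration.
    gradients = self.optimizer.compute_gradients(self.loss)
    grads, variables = zip(*gradients)
    clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)  # None gradients pass through unchanged
    self.update_model = self.optimizer.apply_gradients(list(zip(clipped_grads, variables)))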

Update Method:

def td_update(self, current_state, last_action, next_state, reward):
    """Updates the Q-function using the Q-learning (max over next actions) target"""
    # Update the replay table
    self.replay_table[self.transition_count % self.replay_size] = (current_state, last_action, next_state, reward)
    self.transition_count = (self.transition_count + 1)

    # Don't start learning until transition table has some data
    if self.transition_count >= self.update_size * 20:
        if self.transition_count == self.update_size * 20:
            print("Replay Table is Ready\n")

        # Get a random subsection of the replay table for mini-batch update
        random_tbl = random.choice(self.replay_table[:min(self.transition_count,self.replay_size)],size=self.update_size)
        feature_vectors = np.vstack(random_tbl['state'])
        actions = random_tbl['action']
        next_feature_vectors = np.vstack(random_tbl['next_state'])
        rewards = random_tbl['reward']

        # Get the indices of the non-terminal states
        non_terminal_ix = np.where(~np.any(np.isnan(next_feature_vectors), axis=(1, 2, 3)))[0]

        q_current = self.get_Q_values(feature_vectors)
        # Default q_next will be all zeros (this encompasses terminal states)
        q_next = np.zeros([self.update_size,len(self._environment.action_list)])
        q_next[non_terminal_ix] = self.get_Q_values(next_feature_vectors[non_terminal_ix])

        # The target should be equal to q_current in every place
        target = q_current.copy()

        # Only actions that have been taken should be updated with the reward
        # This means that the target - q_current will be [0 0 0 0 0 0 x 0 0....] 
        # so the gradient update will only be applied to the action taken
        # for a given feature vector.
        target[np.arange(len(target)), actions] += (rewards + self.gamma*q_next.max(axis=1))

        # Logging
        if self.log_file is not None:
            print ("Current Q Value: {}".format(q_current),file=self.log_file)
            print ("Next Q Value: {}".format(q_next),file=self.log_file)
            print ("Current Rewards: {}".format(rewards),file=self.log_file)
            print ("Actions: {}".format(actions),file=self.log_file)
            print ("Targets: {}".format(target),file=self.log_file)

            # Log some of the gradients to check for gradient explosion
            loss, output_grad, conv_grad = self.sess.run([self.loss,self.output_gradient,self.convolutional_gradient],
                                                         feed_dict={self.target_Q: target, self.input_matrix: feature_vectors})
            print ("Loss: {}".format(loss),file=self.log_file)
            print ("Output Weight Gradient: {}".format(output_grad),file=self.log_file)
            print ("Convolutional Gradient: {}".format(conv_grad),file=self.log_file)

        # Update the model
        self.sess.run(self.update_model, feed_dict={self.target_Q: target, self.input_matrix: feature_vectors})
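
To make the indexing on the target concrete, here is a tiny standalone example with made-up numbers; only the indexing mechanics carry over from the method above:

    # Standalone illustration of the target indexing above; the numbers are made up.
    import numpy as np

    q_current = np.array([[0.1, 0.2, 0.3],
                          [0.4, 0.5, 0.6]])
    actions = np.array([2, 0])        # action taken for each sample in the mini-batch
    td_terms = np.array([0.9, -0.7])  # stands in for rewards + gamma * q_next.max(axis=1)

    target = q_current.copy()
    target[np.arange(len(target)), actions] += td_terms

    print(target - q_current)
    # Non-zero only at the action taken in each row (up to floating-point rounding):
    # [[ 0.   0.   0.9]
    #  [-0.7  0.   0. ]]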

I've hand-checked the target and I believe it has the correct values, based on what I know about the algorithm and on my test results. If anyone can give me insight as to why this might happen or what I can do to prevent it, I'd be incredibly grateful. If I need to provide more information, please let me know.

Comments:
  • I didn't see a target network here. Seems the `q_current` and `q_next` come from the same network, which is not the way described in the DQN paper. Check [my implementation](https://github.com/zaxliu/dqn4wirelesscontrol/blob/master/rl/qnn_theano.py) if you like. – zaxliu Apr 14 '17 at 02:42
  • And actually I tend to use `tanh()` as the output non-linearity. That way the gradient is always bounded, and you can use a scaling factor to prevent the output from saturating. – zaxliu Apr 14 '17 at 02:44
  • @zaxliu Thanks for the reply, you're right that I wasn't using a target network. I had thought that, while the target network was an improvement on DQN, it wasn't strictly necessary. Would that be the cause of a Q-value explosion? – mattdeak Apr 16 '17 at 19:11
  • Also, I have tried a tanh non-linearity and gotten much better results since asking this question, but I was wondering if it was mathematically sound given that you are no longer trying to find the actual Q-Value. – mattdeak Apr 16 '17 at 19:12
  • Also, if you use a tanh non-linearity in your network, should you apply tanh to your target before running an update on the network? – mattdeak Apr 16 '17 at 19:34
  • You are right that after applying `tanh()` the network output may not be the true Q-value, which is often the case if the true Q-value is beyond the [-1, 1] range. But that is okay for two reasons: first, the Q-value will still move in the right direction even if the target is a squashed version of the true target, which is guaranteed by the Bellman equation. Second, you don't need the network output to be exact for the inference to be correct; what is important is the relative magnitude of the different action values. These two factors make `tanh` usable (rough sketches of a target network and a scaled-tanh head follow these comments). – zaxliu Oct 19 '17 at 03:14
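
For reference, here are rough sketches of the two ideas from the comments, in the same TF1 style as the code above. These are illustrations only, not code from my project: the variable scope names `'online'` and `'target'`, and the names `hidden`, `n_actions`, and `q_scale`, are placeholders I made up.

    # Sketch only: a separate target network whose weights are periodically copied
    # from the online network. The scopes 'online' and 'target' are assumed names.
    online_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='online')
    target_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='target')
    self.sync_target = [t.assign(o) for o, t in zip(online_vars, target_vars)]

    # In td_update, q_next would then be computed by the target network, and
    # self.sess.run(self.sync_target) would be called every N training steps.

And a scaled `tanh()` output head, which keeps the Q-value outputs (and therefore the Bellman targets built from them) bounded:

    # Sketch only: bound the Q-value head with a scaled tanh.
    # `hidden`, `n_actions`, and `q_scale` are placeholder names for illustration.
    q_raw = tf.layers.dense(hidden, n_actions)   # unbounded linear output
    self.Q_values = q_scale * tf.tanh(q_raw)     # outputs lie in (-q_scale, q_scale)

As zaxliu notes, the outputs are then a squashed version of the true Q-values, but what matters for the greedy policy is the relative magnitude of the action values.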
