
So I am training a DDQN to play Connect Four at the moment. At each state, the network predicts the best action and the agent moves accordingly. The code basically looks as follows:

for epoch in range(num_epochs):
    # reset the prediction/target buffers for this batch of games
    preds, targets = torch.empty(0), torch.empty(0)
    for i in range(batch_size):
        while game is not finished:
            action = select_action(state)
            new_state = play_move(state, action)
            pred, target = get_pred_target(state, new_state, action)
            preds = torch.cat([preds, pred])
            targets = torch.cat([targets, target])
            state = new_state  # advance the game to the next position
    loss = loss_fn(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

While training, the network gets a little better, but nowhere near as good as I would expect. Thinking about it, I am now wondering whether I have actually implemented the loss.backward() call correctly. The point is that I am saving all the predictions and targets for each move in the tensors preds and targets. However, I am not tracking the states that led to these predictions and targets. Isn't that necessary for backpropagation, or is this information somehow saved? A small standalone toy example of the concatenation pattern I am using is shown below (the linear layer and random states are just placeholders, not my actual network):
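import torch
import torch.nn as nn

net = nn.Linear(4, 2)                     # placeholder for the Q-network
preds = torch.empty(0)

for _ in range(3):
    state = torch.randn(4)                # placeholder state
    pred = net(state).max().unsqueeze(0)  # prediction taken from the network output
    preds = torch.cat([preds, pred])      # concatenate, as in my training loop

print(preds.grad_fn)                      # a CatBackward node, so a graph is attached
loss = preds.sum()
loss.backward()
print(net.weight.grad is not None)        # True: gradients reach the network parameters

What I am unsure about is whether this attached graph is really all that is needed, given that I never explicitly keep the states themselves around.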

Thank you very much!

spadel