5

This is an issue I'm running into while converting DQN to Double DQN for the CartPole problem. I'm getting close to figuring it out.

tensor([0.1205, 0.1207, 0.1197, 0.1195, 0.1204, 0.1205, 0.1208, 0.1199, 0.1206,
        0.1199, 0.1204, 0.1205, 0.1199, 0.1204, 0.1204, 0.1203, 0.1198, 0.1198,
        0.1205, 0.1204, 0.1201, 0.1205, 0.1208, 0.1202, 0.1205, 0.1203, 0.1204,
        0.1205, 0.1206, 0.1206, 0.1205, 0.1204, 0.1201, 0.1206, 0.1206, 0.1199,
        0.1198, 0.1200, 0.1206, 0.1207, 0.1208, 0.1202, 0.1201, 0.1210, 0.1208,
        0.1205, 0.1205, 0.1201, 0.1193, 0.1201, 0.1205, 0.1207, 0.1207, 0.1195,
        0.1210, 0.1204, 0.1209, 0.1207, 0.1187, 0.1202, 0.1198, 0.1202])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True])

Here are two tensors. The first holds the Q-values I want, but some of them need to be set to zero because they correspond to an end state. The second tensor marks where the zeros should go.

Wherever the Boolean tensor is False, the value at the same index in the first tensor needs to become zero. I am not sure how to do that.

Muhammad Usman Bashir
DmiSH

2 Answers

5

You can use torch.where: torch.where(condition, x, y) takes elements from x where condition is True and from y everywhere else.

Ex.:

>>> from torch import tensor
>>> x = tensor([0.2853, 0.5010, 0.9933, 0.5880, 0.3915, 0.0141, 0.7745,
...             0.0588, 0.4939, 0.0849])
>>> condition = tensor([False,  True,  True,  True, False, False,  True,
...                     False, False, False])

>>> # Equivalent to `torch.where(condition, x, tensor(0.0))`
>>> x.where(condition, tensor(0.0))
tensor([0.0000, 0.5010, 0.9933, 0.5880, 0.0000, 0.0000, 0.7745,
        0.0000, 0.0000, 0.0000])
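
Applied to the tensors in the question, this might look like the sketch below. The names next_q, not_done, rewards and gamma are hypothetical stand-ins (the question does not show its variable names), and the last lines assume the usual r + gamma * Q form of a DQN-style target, which may differ from your actual update.

```python
import torch

# Hypothetical stand-ins for the question's tensors: a small batch of
# next-state Q-values and a mask that is False where the episode ended.
next_q = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
not_done = torch.tensor([True, True, False, True])

# Keep next_q where not_done is True, use 0.0 everywhere else.
masked_q = torch.where(not_done, next_q, torch.zeros_like(next_q))

# Same result by multiplying with the mask cast to the Q-values' dtype.
masked_q_mul = next_q * not_done.to(next_q.dtype)

# Assumed for illustration only: a typical target of the form r + gamma * Q.
gamma = 0.99
rewards = torch.ones(4)
targets = rewards + gamma * masked_q

print(masked_q)
print(targets)
```

Both variants leave next_q untouched and return a new tensor, which is usually what you want if the original values are still needed elsewhere in the update.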
Dishin H Goyani
3

If the top tensor is your value tensor and the bottom one is the decision tensor, then

value_tensor[~decision_tensor] = 0

Note that this modifies value_tensor in place. You could also convert both tensors to NumPy arrays and perform the same operation there.
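
A minimal sketch of both routes, assuming small example tensors in place of the ones from the question. The names zeroed, zeroed_fill, values_np and mask_np are only illustrative. Since index assignment works in place, the sketch operates on copies so the original values survive.

```python
import numpy as np
import torch

# Small example tensors standing in for the ones in the question.
value_tensor = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
decision_tensor = torch.tensor([True, True, False, True])

# Boolean-mask assignment on a clone, so value_tensor itself is unchanged.
zeroed = value_tensor.clone()
zeroed[~decision_tensor] = 0.0

# Out-of-place alternative: masked_fill returns a new tensor.
zeroed_fill = value_tensor.masked_fill(~decision_tensor, 0.0)

# NumPy route: .numpy() shares memory with the tensor, so copy first.
values_np = value_tensor.numpy().copy()
mask_np = decision_tensor.numpy()
values_np[~mask_np] = 0.0

print(zeroed)
print(values_np)
```

If the tensors live on a GPU or are part of the autograd graph, staying in PyTorch (for example with masked_fill) avoids the extra detach and device transfer that the NumPy route would need.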

Anurag Reddy