
Using torch==1.7.1

MRE:

import torch

lst = [torch.rand(4).reshape(1, -1) for _ in range(5)]

For each tensor, I want to get the index of its largest value:

max_indexes = [torch.argmax(tensor) for tensor in lst]

which outputs:

[tensor(0), tensor(0), tensor(2), tensor(2), tensor(2)]

I want to use these positions to grab the corresponding value from a list of actions. How can I strip the tensor wrapper and get just the plain integer indexes in max_indexes?

The desired output looks like this:

[0, 0, 2, 2, 2]
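A minimal sketch of two ways to get plain Python ints, assuming a standard PyTorch install. `Tensor.item()` converts a one-element (here 0-d) tensor into a Python number, and `Tensor.tolist()` converts a 1-D tensor into a list of Python numbers. The seed and the `actions` list are hypothetical, added only to make the example reproducible and to show the indexing step.

```python
import torch

# Same setup as the MRE: five tensors of shape (1, 4)
torch.manual_seed(0)  # hypothetical seed, only for reproducibility
lst = [torch.rand(4).reshape(1, -1) for _ in range(5)]

# Option 1: argmax per tensor; .item() turns each 0-d result into a Python int
max_indexes = [torch.argmax(t).item() for t in lst]

# Option 2: concatenate into one (5, 4) tensor, take argmax along dim=1,
# then convert the resulting 1-D tensor into a list of Python ints
max_indexes_vec = torch.cat(lst, dim=0).argmax(dim=1).tolist()

# Hypothetical list of actions, used only to show indexing with the ints
actions = ["up", "down", "left", "right"]
chosen = [actions[i] for i in max_indexes]

print(max_indexes)
print(chosen)
```

Option 2 avoids the Python-level loop over argmax calls, which may matter if the list is long; both options should produce the same indexes.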

Thanks!

haneulkim
