Using torch==1.7.1
MRE:
lst = [torch.rand(4).reshape(1,-1) for _ in range(5)]
For each tensor, I want to get the index of the largest value:
max_indexes = [torch.argmax(tensor) for tensor in lst]
which outputs:
[tensor(0), tensor(0), tensor(2), tensor(2), tensor(2)]
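Each element of that list is a zero-dimensional tensor, which is why it prints as `tensor(0)` rather than `0`. The sketch below uses one hand-picked (1, 4) tensor in place of a random element of `lst`. It shows that `torch.argmax` without `dim` indexes into the flattened tensor and returns a 0-d result. With `dim=1` it returns one index per row. For a single-row tensor both give the same column index.

```python
import torch

# A fixed (1, 4) tensor standing in for one element of lst
t = torch.tensor([[0.1, 0.9, 0.3, 0.2]])

idx = torch.argmax(t)              # no dim: index into the flattened tensor
print(idx, idx.dim(), idx.shape)   # tensor(1) 0 torch.Size([])

row_idx = torch.argmax(t, dim=1)   # one index per row, shape (1,)
print(row_idx)                     # tensor([1])
```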
I want to use these positions to grab the appropriate value from a list of actions. How can I strip the tensor wrapper and get just the plain integer index?
The desired output looks like this:
[0, 0, 2, 2, 2]
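A minimal sketch of two common ways to get plain Python ints. It assumes the same `lst` as in the MRE. The seed and the `actions` list are hypothetical and are included only so the example is reproducible and shows the final lookup. `Tensor.item()` converts a single-element tensor to a Python number. `torch.cat` followed by `argmax(dim=1)` and `tolist()` does the same for all rows in one call.

```python
import torch

torch.manual_seed(0)  # hypothetical seed, only for reproducibility
lst = [torch.rand(4).reshape(1, -1) for _ in range(5)]

# Option 1: convert each 0-d argmax result to a Python int with .item()
max_indexes = [torch.argmax(t).item() for t in lst]

# Option 2: concatenate into one (5, 4) tensor, take the argmax of each row,
# then convert the resulting 1-D tensor to a Python list of ints
max_indexes_vec = torch.cat(lst, dim=0).argmax(dim=1).tolist()

assert max_indexes == max_indexes_vec

# Use the plain ints to index a list (hypothetical action names)
actions = ["left", "right", "up", "down"]
chosen = [actions[i] for i in max_indexes]
print(max_indexes, chosen)
```

The second option avoids a Python-level loop over `argmax` calls, which can matter if the list is long. The first keeps the original list-comprehension shape.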
Thanks!