For a given 2D tensor I want to retrieve all indices where the value is 1. I expected to be able to simply use torch.nonzero(a == 1).squeeze(), which would return tensor([1, 3, 2]). However, instead, torch.nonzero(a == 1) returns a 2D tensor (that's okay), with two values per row (that's not what I expected). The returned indices should then be used to index the second dimension (index 1) of a 3D tensor, again returning a 2D tensor.
import torch
a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
Evidently, this does not work, leading to a RuntimeError:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
It seems that idxs is not what I expect it to be, nor can I use it the way I thought. idxs is
tensor([[0, 1],
        [1, 3],
        [2, 2]])
but reading through the documentation I don't understand why I also get back the row indices in the resulting tensor. Now, I know I can get the correct idxs by slicing idxs[:, 1], but even then I cannot use those values as indices for the 3D tensor, because the same error as before is raised. Is it possible to use the 1D tensor of indices to select items across a given dimension?
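For reference, here is a minimal sketch of what I am trying to achieve, as far as I understand it. Each row of the torch.nonzero result appears to be one (row, column) coordinate of a match, which would explain the two values per row. The names rows, cols, selected, index and gathered are my own, and the gather variant assumes exactly one match per row of a, listed in row order; I am not certain this is the intended or most idiomatic approach.

```python
import torch

a = torch.tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])
b = torch.rand(3, 4, 8)

# Each row of the result is one (row, column) coordinate where a == 1.
idxs = torch.nonzero(a == 1)
rows, cols = idxs[:, 0], idxs[:, 1]

# Advanced indexing: pair rows[i] with cols[i] and keep the last dimension whole.
selected = b[rows, cols]  # shape (3, 8)

# gather needs an index with as many dimensions as b, so the 1D column
# indices are reshaped to (3, 1, 1) and expanded to (3, 1, 8) first.
# Assumption: exactly one match per row of a, in row order.
index = cols.view(-1, 1, 1).expand(-1, 1, b.size(2))
gathered = b.gather(1, index).squeeze(1)  # shape (3, 8)

print(selected.size())
print(torch.equal(selected, gathered))
```

The indexing form b[rows, cols] does not depend on the one-match-per-row assumption, since it uses the row coordinates explicitly; the gather form is only shown because that is what I tried originally.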