
For a given 2D tensor I want to retrieve all indices where the value is 1. I expected to be able to simply use torch.nonzero(a == 1).squeeze(), which would return tensor([1, 3, 2]). However, instead, torch.nonzero(a == 1) returns a 2D tensor (that's okay), with two values per row (that's not what I expected). The returned indices should then be used to index the second dimension (index 1) of a 3D tensor, again returning a 2D tensor.

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])

print(b.gather(1, idxs))

Evidently, this does not work, leading to a RuntimeError:

RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453

It seems that idxs is not what I expect it to be, nor can I use it the way I thought. idxs is

tensor([[0, 1],
        [1, 3],
        [2, 2]])

but reading through the documentation I don't understand why I also get back the row indices in the resulting tensor. Now, I know I can get the correct idxs by slicing idxs[:, 1] but then still, I cannot use those values as indices for the 3D tensor because the same error as before is raised. Is it possible to use the 1D tensor of indices to select items across a given dimension?

kmario23
Bram Vanroy

4 Answers


You could simply slice idxs column-wise and pass the columns as indices, as in:

In [193]: idxs = torch.nonzero(a == 1)     
In [194]: c = b[idxs[:, 0], idxs[:, 1]]  

In [195]: c   
Out[195]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])

Alternatively, an even simpler (and my preferred) approach is to use torch.where() and then directly index into the tensor b, as in:

In [196]: b[torch.where(a == 1)]  
Out[196]: 
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
        [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
        [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])

A bit more explanation about the above approach of using torch.where(): it works based on the concept of advanced indexing, which kicks in when we index into a tensor using a tuple of sequence objects, such as a tuple of tensors, a tuple of lists, or a tuple of tuples.

# some input tensor
In [207]: a  
Out[207]: 
tensor([[12.,  1.,  0.,  0.],
        [ 4.,  9., 21.,  1.],
        [10.,  2.,  1.,  0.]])

For basic slicing, we would need a tuple of integer indices:

In [212]: a[(1, 2)] 
Out[212]: tensor(21.)

To achieve the same using advanced indexing, we would need a tuple of sequence objects:

# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])] 
Out[213]: tensor([21.])

# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]  
Out[215]: tensor([21.])

# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))] 
Out[214]: tensor([21.])

When indexing with a tuple of 1-D index sequences like this, the returned tensor always has one dimension fewer than the input tensor: each pair of indices selects a single element (or, for higher-dimensional inputs, a slice along the remaining dimensions).
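A minimal sketch of that dimensionality rule, using the shapes from the question (the index values here are just illustrative):

```python
import torch

# Advanced indexing with a tuple of two 1-D index tensors selects one
# (row, col) pair per entry, so the result drops one dimension.
b = torch.rand(3, 4, 8)
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([1, 3, 2])

c = b[rows, cols]           # same as b[(rows, cols)]
print(c.shape)              # torch.Size([3, 8]) -- 3-D input, 2-D result

# With a 2-D input, the same pair of 1-D indices yields a 1-D result.
a = torch.rand(3, 4)
print(a[rows, cols].shape)  # torch.Size([3])
```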

kmario23
    I very much love the simplicity of your second suggestion. Could you explain a bit more *why* this works? Since `torch.where(a == 1)` returns a tuple. How does slicing a tensor with a tuple work like that? – Bram Vanroy Sep 27 '19 at 16:47
  • @BramVanroy added some explanation :) – kmario23 Sep 27 '19 at 17:14

Assuming that b's three dimensions are batch_size x sequence_length x features (b x s x feats), the expected results can be achieved as follows.

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)
print(b.size())
# b x s x feats
idxs = torch.nonzero(a == 1)[:, 1]
print(idxs.size())
# b
c = b[torch.arange(b.size(0)), idxs]
print(c.size())
# b x feats
Bram Vanroy
import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

#idxs = torch.nonzero(a == 1, as_tuple=True)
idxs = torch.nonzero(a == 1)
#print('idxs_size', idxs.size())

print(torch.index_select(b, 1, idxs[:, 1]))
  • No, not at all. First of all `z` is a tuple, second my final goal is to get an index from `b`. – Bram Vanroy Sep 27 '19 at 14:52
  • yeah, sorry, I completely misread your question... for your final question, "Is it possible to use the 1D tensor of indices to select items across a given dimension", does torch.index_select(b, 1, idxs[:, 1]) give you what you need? – user2066337 Sep 27 '19 at 15:26
  • No. That'll return a 3D matrix. I expect a 2D one. See my answer. – Bram Vanroy Sep 27 '19 at 15:57
  • ah, gotcha. glad you figured it out. sidenote just make sure there aren't 2 1's in a row – user2066337 Sep 27 '19 at 16:22
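The caveat in the last comment matters: the arange-based answer assumes exactly one 1 per row of a. A small sketch (with a hypothetical input containing a duplicate) shows how the alignment breaks:

```python
import torch

# If a row contains two 1s, torch.nonzero(a == 1) returns more index
# rows than a has rows, so pairing its second column with
# torch.arange(b.size(0)) no longer lines up one-to-one.
a = torch.tensor([[12., 1., 0., 1.],   # two 1s in this row
                  [ 4., 9., 21., 1.],
                  [10., 2., 1., 0.]])
idxs = torch.nonzero(a == 1)
print(idxs.size(0), a.size(0))  # 4 matches vs. 3 rows
```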

As a supplement to @kmario23's solution, you can achieve the same result with:

b[torch.nonzero(a==1,as_tuple=True)]
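A quick check (using the tensors from the question) that torch.nonzero(..., as_tuple=True) returns the same index tuple as torch.where(...), so both index b identically:

```python
import torch

a = torch.tensor([[12., 1., 0., 0.],
                  [ 4., 9., 21., 1.],
                  [10., 2., 1., 0.]])
b = torch.rand(3, 4, 8)

nz = torch.nonzero(a == 1, as_tuple=True)
wh = torch.where(a == 1)

# Both return a tuple of (row, col) index tensors.
assert all(torch.equal(x, y) for x, y in zip(nz, wh))
print(b[nz].shape)  # torch.Size([3, 8])
```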
zihaozhihao