I am given a PyTorch 2-D tensor of integers, and two integers that always appear in each row of the tensor.
I want to create a binary mask that contains 1 between the two appearances of these two integers (inclusive), and 0 otherwise. For example, if the integers are 4 and 2 and the 1-D array is [1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4], the returned mask will be: [0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0].
Is there an efficient and quick way to compute this mask without iterations?

Codevan
2 Answers
Perhaps a bit messy, but it works without iterations. In the following I assume an example tensor `m` to which I apply the solution; that is easier to explain than general notation.
```python
import torch

vals = [2, 8]  # let's assume these are the constant values that appear in each row
# target tensor
m = torch.tensor([[1., 2., 7., 8., 5.],
                  [4., 7., 2., 1., 8.]])
# let's find the indexes of those values
k = m == vals[0]
p = m == vals[1]
v = k | p  # True wherever either value occurs
# column indexes of the hits, two per row (assumes each value occurs exactly once per row)
nz_indexes = v.nonzero()[:, 1].reshape(m.shape[0], 2)
# let's create a tiling of the column indexes
q = torch.arange(m.shape[1])
q = q.repeat(m.shape[0], 1)
# you only need two masks, no matter the size of m; see explanation below
msk_0 = nz_indexes[:, 0].repeat(m.shape[1], 1).transpose(0, 1) <= q
msk_1 = nz_indexes[:, 1].repeat(m.shape[1], 1).transpose(0, 1) >= q
final_mask = msk_0.int() * msk_1.int()
print(final_mask)
```
and we get:

```
tensor([[0, 1, 1, 1, 0],
        [0, 0, 1, 1, 1]], dtype=torch.int32)
```
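Regarding the two masks `msk_0` and `msk_1`, in case it's not clear what they are: `nonzero()` lists the hits of each row in increasing column order, so `nz_indexes[:,0]` contains, for each row of `m`, the column index of whichever value appears first (`vals[0]` in this example), and `nz_indexes[:,1]` similarly contains the column index of the value that appears second (`vals[1]` here). `msk_0` is therefore True from the first occurrence onwards, `msk_1` is True up to and including the second occurrence, and their product is 1 exactly on the span between them.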
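The `repeat(...).transpose(0, 1)` calls only copy each row's index across all columns, which PyTorch broadcasting does implicitly. Below is a minimal sketch of the same two masks using broadcasting instead, under the same assumption as the answer (each value occurs exactly once per row); the name `hits` is introduced here for illustration.

```python
import torch

vals = [2, 8]
m = torch.tensor([[1., 2., 7., 8., 5.],
                  [4., 7., 2., 1., 8.]])

# Column indexes of the two hits in each row (exactly two hits per row assumed).
hits = (m == vals[0]) | (m == vals[1])
nz_indexes = hits.nonzero()[:, 1].reshape(m.shape[0], 2)

# A single row of column positions; broadcasting stretches it over all rows.
q = torch.arange(m.shape[1])

# (rows, 1) compared against (cols,) broadcasts to (rows, cols).
msk_0 = nz_indexes[:, 0:1] <= q
msk_1 = nz_indexes[:, 1:2] >= q
final_mask = (msk_0 & msk_1).int()
print(final_mask)
```

Slicing with `0:1` rather than indexing with `0` keeps a column dimension, so no repeated copies of the indexes have to be built explicitly.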

Ash
- My code fails on the line `nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)`. Why do you call `nonzero()` on `v` and keep only one column of the result? – Codevan Nov 10 '20 at 10:38
- @Codevan Hmm... Yeah, I think this will fail if either `vals[0]` or `vals[1]` appears more than once in a single row. When writing that line I assumed each value appears only once per row. Since `v.nonzero()[:,0]` holds the row indexes, I could discard that column: I already know that `v.nonzero()[0,:]` and `v.nonzero()[1,:]` correspond to row `0` of `m`, and so on. – Ash Nov 10 '20 at 10:46
- I think one way to fix that is to replace `nonzero()` with a function that returns the first and last occurrences of the non-zero elements instead. – Ash Nov 10 '20 at 10:48
- I edited my question accordingly. Do you know how to adapt your answer now? Thanks! – Codevan Nov 12 '20 at 21:35
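Ash's suggestion of using the first and last occurrences can be sketched without `nonzero()`: fill the non-hit positions with a sentinel column index and take a row-wise `min` and `max`. This sketch assumes every row contains at least one hit. Note that it produces one span per row, from the first hit to the last hit, so on rows that contain two pairs (such as the tensor used in the revised answer below) the gap between the pairs is filled with 1s as well; `hits`, `first`, `last` and `outer_mask` are illustrative names.

```python
import torch

vals = [2, 8]
m = torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
                  [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])

hits = (m == vals[0]) | (m == vals[1])
n_rows, n_cols = m.shape
q = torch.arange(n_cols).repeat(n_rows, 1)  # column index of every position

# Non-hit positions get a sentinel that can never win the min (n_cols) or the max (-1).
first = torch.where(hits, q, torch.full_like(q, n_cols)).min(dim=1).values
last = torch.where(hits, q, torch.full_like(q, -1)).max(dim=1).values

# 1 from the first hit to the last hit (inclusive) in each row.
outer_mask = ((q >= first[:, None]) & (q <= last[:, None])).int()
print(outer_mask)
```

In the first row, column 4 comes out as 1 even though it lies between the two pairs; with a single pair per row, this sketch gives the same mask as the answer above.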
Based completely on Ash's solution above, here is the revised one:
```python
import torch

vals = [2, 8]  # let's assume these are the constant values that appear in each row
# target tensor: each value now appears twice per row
m = torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
                  [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
# let's find the indexes of those values
k = m == vals[0]
p = m == vals[1]
v = k | p  # True wherever either value occurs
# four hits per row (two pairs); the reshape only works if every row has exactly four hits
nz_indexes = v.nonzero()[:, 1].reshape(m.shape[0], 4)
# let's create a tiling of the column indexes
q = torch.arange(m.shape[1])
q = q.repeat(m.shape[0], 1)
# two masks per pair: msk_0/msk_1 for the first pair, msk_2/msk_3 for the second
msk_0 = nz_indexes[:, 0].repeat(m.shape[1], 1).transpose(0, 1) <= q
msk_1 = nz_indexes[:, 1].repeat(m.shape[1], 1).transpose(0, 1) >= q
msk_2 = nz_indexes[:, 2].repeat(m.shape[1], 1).transpose(0, 1) <= q
msk_3 = nz_indexes[:, 3].repeat(m.shape[1], 1).transpose(0, 1) >= q
final_mask = msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()
print(final_mask)
```
and we finally get:

```
tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```
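The hand-written masks `msk_0` to `msk_3` follow a fixed pattern: the hits form (start, end) pairs. One way to generalize this to any fixed number of pairs per row, assuming every row has the same number of hits and that they alternate start/end in column order, is to reshape the hit columns to `(rows, pairs, 2)` and let broadcasting build all the masks at once. This is a sketch, not part of the answer; `pairs`, `starts`, `ends` and `in_pair` are illustrative names.

```python
import torch

vals = [2, 8]
m = torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
                  [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])

hits = (m == vals[0]) | (m == vals[1])
n_rows, n_cols = m.shape

# (rows, pairs, 2): start and end column of every pair; -1 infers the number of pairs.
# Assumes each row has the same, even number of hits.
pairs = hits.nonzero()[:, 1].reshape(n_rows, -1, 2)
starts = pairs[..., 0:1]  # (rows, pairs, 1)
ends = pairs[..., 1:2]    # (rows, pairs, 1)

q = torch.arange(n_cols)  # (cols,), broadcast to (rows, pairs, cols)
in_pair = (starts <= q) & (q <= ends)

# A column is inside a span if it lies inside any of the row's pairs.
final_mask = in_pair.any(dim=1).int()
print(final_mask)
```

The masks no longer need to be written out one by one, but the `reshape` still requires every row to contain the same number of pairs.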

Codevan
- For the future, please don't change the question **after** you have already accepted an answer: not only is that rude, it also makes the people who answered you look like complete morons. As for the solution itself, it would have been better to find something that generalizes: what if there are N repetitions in the array? It doesn't seem very elegant or efficient to have N*2 `msk_i` arrays. – Ash Nov 13 '20 at 16:24
- @Ash is right, and you should at least upvote his answer. – Hovercraft Full Of Eels Nov 13 '20 at 16:27
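Regarding the N-repetitions point above, one alternative that is not taken from either answer is a cumulative-sum trick: along each row, count the openings (first value) and closings (second value) seen so far; the difference is positive exactly between an opening and the next closing. Below is a minimal sketch with a hypothetical helper `span_mask`, assuming the first value always opens a span, the second always closes it, spans don't nest, and every opening is eventually closed (an unclosed opening would leave 1s up to the end of the row). Unlike the answers above, rows may contain different numbers of pairs, but the order of the two values matters; the example tensor `x` is made up for illustration.

```python
import torch

def span_mask(x: torch.Tensor, start_val: int, end_val: int) -> torch.Tensor:
    """Hypothetical helper: 1 from each start_val up to and including the next end_val, else 0.

    Assumes spans alternate start/end along the last dimension and every start is closed.
    """
    opens = (x == start_val).int()
    closes = (x == end_val).int()
    # Spans opened minus spans closed so far; adding `closes` back keeps
    # the closing position itself inside the span.
    depth = opens.cumsum(dim=-1) - closes.cumsum(dim=-1) + closes
    return (depth > 0).int()

# Made-up rows with two pairs and one pair of (4, 2), respectively.
x = torch.tensor([[1, 4, 6, 2, 9, 4, 3, 2],
                  [4, 5, 5, 5, 2, 7, 7, 7]])
print(span_mask(x, 4, 2))
```

Called as `span_mask(m, 2, 8)` on the tensor from the revised answer, it reproduces that answer's output without any per-pair masks.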