I have three 3D PyTorch tensors, `a`, `b` and `c`. First, for each row in `b`, I want to find all the elements that also exist in the corresponding row in `a`. Then, I want to create a new tensor containing the elements of `c`, based on the overlap between `a` and `b`. None of the rows contain duplicates. Here is a dummy example:
```python
import torch

a = torch.tensor([[[0, 3, 2, 7]],
                  [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]],
                  [[5, 8, 2, 3]]])
c = torch.tensor([[[7, 2, 8, 4]],
                  [[1, 6, 7, 3]]])
```
For instance, element `b[0,0,0]` is `3`. The corresponding element in `c` is `7`. This `7` should be moved to position `[0, 0, 1]` in the output tensor, since that is the position of the `3` in tensor `a`. Elements of `c` whose counterpart in `b` does not appear in the corresponding row of `a` are dropped, and every output position whose element in `a` has no match in `b` should be set to `0`.
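The lookup for that single element can be sketched like this: compare `b[0,0,0]` against the matching row `a[0,0]` and read off the index where they are equal. This is only an illustration of the walk-through above, not a full solution.

```python
import torch

a = torch.tensor([[[0, 3, 2, 7]],
                  [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]],
                  [[5, 8, 2, 3]]])
c = torch.tensor([[[7, 2, 8, 4]],
                  [[1, 6, 7, 3]]])

# Position(s) in row a[0, 0] that hold the value b[0, 0, 0] (which is 3).
pos = (a[0, 0] == b[0, 0, 0]).nonzero(as_tuple=True)[0]
# The value of c that travels with b[0, 0, 0].
value = c[0, 0, 0]
print(pos, value)  # tensor([1]) tensor(7)
```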
Here is what the complete output should look like:

```python
tensor([[[0, 7, 0, 8]],
        [[0, 6, 1, 0]]])
```
The final tensor should still have 3 dimensions.
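One possible approach, sketched here under the stated assumption that no row contains duplicates, is to broadcast an all-pairs equality test between each row of `a` and the same row of `b`, then use that mask to pull the matching values out of `c`. This is a minimal sketch rather than a tuned implementation: the intermediate mask has shape `(..., n, n)` for rows of length `n`, so memory grows quadratically with row length.

```python
import torch

a = torch.tensor([[[0, 3, 2, 7]],
                  [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]],
                  [[5, 8, 2, 3]]])
c = torch.tensor([[[7, 2, 8, 4]],
                  [[1, 6, 7, 3]]])

# match[..., i, j] is True where a[..., i] == b[..., j] (same row only).
match = a.unsqueeze(-1) == b.unsqueeze(-2)          # shape (2, 1, 4, 4)

# For each position i of a, sum c[..., j] over the matching j.
# With no duplicates per row there is at most one match, so the sum
# is either that single value of c or 0 when nothing matches.
out = (match.to(c.dtype) * c.unsqueeze(-2)).sum(dim=-1)  # shape (2, 1, 4)

print(out)
# tensor([[[0, 7, 0, 8]],
#         [[0, 6, 1, 0]]])
```

Because the reduction only removes the extra pairwise dimension, `out` keeps the same 3D shape as `a`.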
I tried using `torch.isin` to create a boolean tensor showing which elements of `b` are also in `a`. However, `torch.isin` checks each element of `b` against every element of `a` rather than only against the corresponding row, so it did not work for this problem.
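The difference between the global check and a row-wise check can be seen with a small sketch. The row-wise variant below is an assumption of mine built from broadcasting and `any`, not a built-in `torch.isin` mode: for row 1, the values `2` and `3` of `b` only occur in row 0 of `a`, so the global check marks them `True` while the row-wise check does not.

```python
import torch

a = torch.tensor([[[0, 3, 2, 7]],
                  [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]],
                  [[5, 8, 2, 3]]])

# Global membership: each element of b is tested against all of a.
global_mask = torch.isin(b, a)

# Row-wise membership: compare b[..., j] with every a[..., i] in the
# same row, then reduce over the a-axis.
rowwise_mask = (b.unsqueeze(-1) == a.unsqueeze(-2)).any(dim=-1)

print(global_mask)   # row 1 is [True, True, True, True]
print(rowwise_mask)  # row 1 is [True, True, False, False]
```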
This post asks a question similar to the first part of mine, but only for 1D NumPy arrays.
This post asks about applying `torch.isin` to a 2D tensor, but the only suggested answer doesn't work for 3D tensors.