
I have three 3D PyTorch tensors, a, b and c. First, for each row of b, I want to find the elements that also exist in the corresponding row of a. Then I want to build a new tensor from the elements of c, placed according to where those shared elements occur in a. No row contains duplicate values. Here is a dummy example:

import torch

a = torch.tensor([[[0, 3, 2, 7]],
                  [[7, 8, 5, 9]]])

b = torch.tensor([[[3, 1, 7, 4]],
                  [[5, 8, 2, 3]]])

c = torch.tensor([[[7, 2, 8, 4]],
                  [[1, 6, 7, 3]]])

For instance, b[0, 0, 0] is 3, and the corresponding element of c, c[0, 0, 0], is 7. This 7 should go to position [0, 0, 1] in the output tensor, because that is where the 3 appears in a. Any output position whose element of a does not appear in the corresponding row of b should be 0; values of c whose element of b has no match in that row of a are simply dropped.
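To pin the rule down, here is a slow plain-loop sketch that applies it element by element. It assumes a, b and c share the same shape and, as stated, that no row contains duplicates; the function name remap_reference is hypothetical. It is only meant as a reference for checking faster versions against.

```python
import torch

def remap_reference(a, b, c):
    """Hypothetical slow reference: loop over rows and apply the rule directly."""
    out = torch.zeros_like(c)  # positions with no match stay 0
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            # Map each value in this row of a to its column index (rows have no duplicates).
            pos = {v: k for k, v in enumerate(a[i, j].tolist())}
            for k, v in enumerate(b[i, j].tolist()):
                if v in pos:  # b[i, j, k] also occurs in the same row of a
                    out[i, j, pos[v]] = c[i, j, k]
    return out

a = torch.tensor([[[0, 3, 2, 7]], [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]], [[5, 8, 2, 3]]])
c = torch.tensor([[[7, 2, 8, 4]], [[1, 6, 7, 3]]])

print(remap_reference(a, b, c).tolist())  # → [[[0, 7, 0, 8]], [[0, 6, 1, 0]]]
```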

Here is what the complete output should look like:

tensor([[[0, 7, 0, 8]],
        [[0, 6, 1, 0]]])

The final tensor should still have 3 dimensions.

I tried using torch.isin to create a boolean tensor showing which elements of b are also in a. However, torch.isin checks each element of b against every element of a, not just against the same row, so it did not work for this problem.
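A row-wise version of that membership check can be sketched with broadcasting instead of torch.isin: compare every element of b with every element of a in the same row, then reduce over the last dimension. This is a sketch under the example's shapes, not a built-in; the names matches and b_in_a are illustrative. It allocates an extra tensor whose last two dimensions are row_length by row_length, which is fine for short rows.

```python
import torch

a = torch.tensor([[[0, 3, 2, 7]], [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]], [[5, 8, 2, 3]]])

# b.unsqueeze(-1) has shape (2, 1, 4, 1); a.unsqueeze(-2) has shape (2, 1, 1, 4).
# Broadcasting gives matches[n, r, i, j] == (b[n, r, i] == a[n, r, j]).
matches = b.unsqueeze(-1) == a.unsqueeze(-2)  # shape (2, 1, 4, 4)

# Row-wise analogue of torch.isin(b, a): does b[n, r, i] occur anywhere in a[n, r]?
b_in_a = matches.any(dim=-1)  # shape (2, 1, 4)
print(b_in_a.tolist())  # → [[[True, False, True, False]], [[True, True, False, False]]]
```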

This post asks a question similar to the first part of mine, but only for 1D NumPy arrays.

This post asks about applying torch.isin to a 2D tensor, but the only suggested answer doesn't work for 3D tensors.
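One possible vectorized approach that keeps the 3D shape combines a per-row pairwise comparison with c. A minimal sketch, assuming a, b and c have the same shape and rows contain no duplicates (the function name remap_by_match is hypothetical): for each output column j, it sums c[..., i] over all i where b[..., i] == a[..., j]. Because each row is duplicate-free, at most one i matches, so the sum is either that single value of c or 0.

```python
import torch

def remap_by_match(a, b, c):
    # matches[..., i, j] is True when b[..., i] == a[..., j], within the same row only.
    matches = b.unsqueeze(-1) == a.unsqueeze(-2)
    # Place c[..., i] at every column j it matched; bool * int promotes to c's integer dtype.
    # With no duplicates, summing over i picks the single matching value, or 0 if none.
    return (matches * c.unsqueeze(-1)).sum(dim=-2)

a = torch.tensor([[[0, 3, 2, 7]], [[7, 8, 5, 9]]])
b = torch.tensor([[[3, 1, 7, 4]], [[5, 8, 2, 3]]])
c = torch.tensor([[[7, 2, 8, 4]], [[1, 6, 7, 3]]])

out = remap_by_match(a, b, c)
print(out.tolist())  # → [[[0, 7, 0, 8]], [[0, 6, 1, 0]]]
print(out.shape)     # → torch.Size([2, 1, 4])
```

The trade-off is memory: the intermediate comparison tensor grows with the square of the row length, so for very long rows a different strategy may be preferable.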

akup
