
Just to be upfront, this is basically the same question as Numpy array loss of dimension when masking, but for PyTorch tensors rather than NumPy arrays. The solutions to that question (using the equivalent PyTorch function torch.where, or masked tensors) work, but a Google search for the PyTorch case did not quickly turn up an answer. So I thought an equivalent question tagged for PyTorch on StackOverflow might be useful for others!

I have a 2D PyTorch tensor (although it could have more dimensions) to which I want to apply a binary mask of the same shape. However, when I apply the mask, the output is 1-dimensional. How can I keep the same dimensions as the original tensor after applying the mask?

E.g., for

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0
print(x[mask])
# tensor([2., 8., 3.])

The output is 1D rather than 2D.

Matt Pitkin
  • Maybe I do not understand the question well, but `torch.where(mask, x, 0)` does preserve the dimensions, replacing the "filtered out" values with 0 – JackRed Aug 14 '23 at 14:40
  • Yes, that does work, which I say in the blurb at the start of the question. I meant to also add that as an/the answer, but hadn't got round to it yet. Feel free to post the answer yourself. I mainly posted the question as (hopefully) an aid to others searching for the same information. – Matt Pitkin Aug 14 '23 at 14:46
  • Makes sense. I guess a community answer would be best in this case – JackRed Aug 14 '23 at 14:48

2 Answers


Calling torch.where with only the mask returns a tuple of index tensors, one per dimension (here, the row and column indices of the True entries):

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0
print(torch.where(mask))
print(x[torch.where(mask)])

Which outputs:

(tensor([0, 0, 1]), tensor([1, 2, 2]))
tensor([2., 8., 3.])

However, indexing x with those indices returns only the selected values. Every element where the mask is False is dropped, so the result has fewer elements than x and cannot keep its shape; PyTorch therefore returns the selected values as a flat 1D tensor.
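To make the flattening concrete, here is a small check (a sketch using only standard tensor methods; `selected`, `rows` and `cols` are just illustrative names). Boolean indexing returns one element per True entry, and `torch.where(mask)` yields the same indices as `mask.nonzero(as_tuple=True)`:

```python
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0

# Boolean indexing keeps only the True positions, so the result is 1D
# with one element per True entry in the mask.
selected = x[mask]
print(selected.shape)   # torch.Size([3])
print(int(mask.sum()))  # 3

# torch.where(mask) returns one index tensor per dimension of x ...
rows, cols = torch.where(mask)
print(torch.equal(x[rows, cols], selected))  # True

# ... which is the same as mask.nonzero(as_tuple=True).
nz_rows, nz_cols = mask.nonzero(as_tuple=True)
print(torch.equal(rows, nz_rows) and torch.equal(cols, nz_cols))  # True
```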

If you want a tensor with x's original shape that marks where the mask is True, you can fill a tensor of zeros with ones at those indices:

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = 1.0
print(masked_x)

Outputs:

tensor([[0., 1., 1.],
        [0., 0., 1.]])

So now masked_x is a tensor of zeros with the same shape as x, but with ones where the mask is True.
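As a shortcut (not part of the original answer), casting the boolean mask to x's dtype produces the same 0/1 tensor directly, without building a zeros tensor and indexing into it:

```python
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0

# True -> 1.0 and False -> 0.0, with the same shape as the mask (and x).
ones_where_true = mask.to(x.dtype)
print(ones_where_true)
# tensor([[0., 1., 1.],
#         [0., 0., 1.]])
```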

If you want masked_x to contain x's values where the mask is True (and zeros elsewhere), then:

masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = x[torch.where(mask)]
print(masked_x)

Outputs:

tensor([[0., 2., 8.],
        [0., 0., 3.]])

i.e., x's values where the mask is True and zeros where it is False.
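For comparison, here is a sketch of three ways to get the same zero-filled, shape-preserving result. `torch.where(mask, x, 0.0)` is the one-liner suggested in the comments; `Tensor.masked_fill` and multiplying by the mask are alternatives added here for illustration:

```python
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0

# Pick x where the mask is True, 0.0 elsewhere.
via_where = torch.where(mask, x, 0.0)

# Overwrite the positions where the mask is False (returns a new tensor).
via_fill = x.masked_fill(~mask, 0.0)

# The bool mask is promoted to x's dtype (1.0 / 0.0) before multiplying.
via_mul = x * mask

print(via_where)
# tensor([[0., 2., 8.],
#         [0., 0., 3.]])
```

The multiplication trick is the least robust: a masked-out `inf` or `nan` in x becomes `nan` after multiplying by 0, and negative masked-out values become `-0.0` (which compares equal to 0.0). `torch.where` and `masked_fill` avoid both issues.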

If you want something else, please clarify.

Ori Yarden PhD
  • You can simplify the use of `where` by just doing `torch.where(mask, x, 0)`, which gives your final result. – Matt Pitkin Aug 14 '23 at 14:49
  • @Matt Pitkin I kept it the way I would do it with `numpy`'s `where`, just for demonstration purposes and to keep it simple, but good point! – Ori Yarden PhD Aug 14 '23 at 14:51

As suggested in the NumPy question, torch.where will do the trick.

To keep the same shape, you need a fill value for the masked-out positions. In the example below I use 0, but you could also use torch.nan.

torch.where(mask, x, 0)

returns

tensor([[0., 2., 8.],
        [0., 0., 3.]])
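If you go with `torch.nan` as the fill value, the masked-out positions can't be confused with genuine zeros, and NaN-aware reductions such as `torch.nanmean` will skip them. A minimal sketch (this only works for floating-point tensors, since integer tensors cannot hold NaN; `row_means` is just an illustrative name):

```python
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >= 2.0

# Same shape as x, with NaN wherever the mask is False.
masked_x = torch.where(mask, x, torch.nan)
print(masked_x)

# NaN-aware reduction: the mean of the kept values in each row.
row_means = torch.nanmean(masked_x, dim=1)
print(row_means)  # tensor([5., 3.])
```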
JackRed