I have an HxWx3 tensor representing an RGB image and an HxWx3 boolean mask tensor as input. It is assumed that for each (i,j) the mask has exactly one True value (that is, exactly one of R/G/B is on). I want to apply the mask to the image to get an HxW (or HxWx1) tensor V where V[i,j] is the R/G/B value that the mask selects at (i,j).
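To make the intended V concrete, here is a minimal PyTorch sketch, assuming channels-last HxWx3 tensors, a `torch.bool` mask and exactly one True per pixel. The function name `select_masked_channel` and the example values are made up for illustration. Boolean indexing returns the selected elements in row-major order, so with one True per pixel and the channel axis innermost, the result already has one value per pixel in (i, j) order.

```python
import torch


def select_masked_channel(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return V (H x W) with V[i, j] = image[i, j, c] for the c where mask[i, j, c] is True.

    Hypothetical helper; assumes channels-last H x W x C tensors, a bool mask,
    and exactly one True per pixel.
    """
    if image.shape != mask.shape or image.dim() != 3:
        raise ValueError("image and mask must both be H x W x C")
    # Check the one-True-per-pixel assumption; summing a bool tensor counts Trues.
    if not bool((mask.sum(dim=-1) == 1).all()):
        raise ValueError("mask must have exactly one True channel per pixel")
    h, w, _ = image.shape
    # Boolean indexing yields the selected elements in row-major order.
    # With the channel axis innermost and one True per pixel, that is one
    # value per pixel, already ordered by (i, j), so a reshape restores H x W.
    return image[mask].reshape(h, w)


# Made-up example: pixel (0,0) selects R, (0,1) R, (1,0) B, (1,1) G.
image = torch.tensor([[[9., 1., 2.], [10., 3., 4.]],
                      [[5., 6., 30.], [7., 20., 8.]]])
mask = torch.tensor([[[True, False, False], [True, False, False]],
                     [[False, False, True], [False, True, False]]])
V = select_masked_channel(image, mask)
print(V.shape)     # torch.Size([2, 2])
print(V.tolist())  # [[9.0, 10.0], [30.0, 20.0]]
```

The validity check is optional; it only guards the "exactly one True" assumption, which the reshape silently relies on.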
Following "Problem applying binary mask to an RGB image with numpy", I was able to get the following:
```
>>> X*mask
tensor([[[ 9., 10.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0., 20.]],

        [[ 0.,  0.],
         [30.,  0.]]])
```
But, as stated, I want a single-channel HxW result, not an HxWx3 one.
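Because each pixel has exactly one True channel, only the selected channel of X*mask can be non-zero, so summing over the channel axis collapses it to HxW. Note that the printed X*mask above has shape (3, 2, 2), i.e. channels first, so this sketch takes the channel axis as a parameter (`channel_dim=0` for that layout, `-1` for HxWx3). A `gather`-based variant is included as an alternative. Both function names are hypothetical, and the unselected values in `X` below are made up, since only X*mask was shown.

```python
import torch


def collapse_by_sum(image: torch.Tensor, mask: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
    # image * mask zeroes every unselected channel, so summing over the channel
    # axis leaves just the selected value (assumes one True per pixel).
    return (image * mask).sum(dim=channel_dim)


def collapse_by_gather(image: torch.Tensor, mask: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
    dim = channel_dim % image.dim()  # normalise a negative axis
    # Index of the True channel at each pixel; keepdim so the index tensor has
    # the same number of dims as image, as gather requires.
    idx = mask.to(torch.int64).argmax(dim=dim, keepdim=True)
    return image.gather(dim, idx).squeeze(dim)


# Channels-first layout, like the printed X*mask (shape 3 x 2 x 2).
X = torch.tensor([[[9., 10.], [1., 2.]],
                  [[3., 4.], [5., 20.]],
                  [[6., 7.], [30., 8.]]])
mask = torch.tensor([[[True, True], [False, False]],
                     [[False, False], [False, True]],
                     [[False, False], [True, False]]])

print(collapse_by_sum(X, mask, channel_dim=0).tolist())     # [[9.0, 10.0], [30.0, 20.0]]
print(collapse_by_gather(X, mask, channel_dim=0).tolist())  # [[9.0, 10.0], [30.0, 20.0]]

# Channels-last (H x W x 3), as described in the question: keep channel_dim=-1.
X_hwc, mask_hwc = X.permute(1, 2, 0), mask.permute(1, 2, 0)
print(collapse_by_sum(X_hwc, mask_hwc).tolist())            # [[9.0, 10.0], [30.0, 20.0]]
```

The sum version reuses the X*mask product from the question. The gather version avoids the multiplication, which matters if an unselected channel can hold inf or NaN, since inf * 0 is NaN.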