I am trying to modify a masked PyTorch tensor inside a function. I observe the same behaviour with NumPy arrays.
from torch import tensor
def foo(x):
    """
    Minimal example.
    The actual function is complex.
    """
    x *= -1
y = tensor([1,2,3])
mask = [False, True, False]
foo(y[mask])
print(y)
# Result: tensor([1, 2, 3]). Expected: tensor([1, -2, 3])
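As far as I understand, indexing with a boolean mask is advanced indexing, which returns a new tensor rather than a view, so foo only negates a temporary copy. A quick illustrative check (basic slicing, by contrast, returns a view):
z = y[mask]
z *= -1       # modifies the copy only
print(y)      # tensor([1, 2, 3]) -- unchanged
v = y[1:2]    # basic slicing shares storage with y
v *= -1
print(y)      # tensor([1, -2, 3]) -- changed through the view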
There are two obvious solutions that I can think of. Both have shortcomings I would like to avoid.
def foo1(x):
    return -x
y = tensor([1,2,3])
mask = [False, True, False]
y[mask] = foo1(y[mask])
This creates a copy of y[mask], which is not ideal for my RAM-bound application.
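A rough way to see the extra allocations (the pointer comparison is only indicative):
tmp = y[mask]       # advanced indexing already allocates a copy
out = foo1(tmp)     # -x allocates another tensor of the same size
print(tmp.data_ptr() != y.data_ptr(), out.data_ptr() != tmp.data_ptr())  # True True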
def foo2(x, m):
    x[m] *= -1
y = tensor([1,2,3])
mask = [False, True, False]
foo2(y, mask)
This works without a copy, but it makes the function messy: it has to be aware of the mask and of the argument's type. For example, it won't work directly on scalars (see the sketch below).
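For instance (the exact error message may vary):
# foo2 assumes its first argument supports in-place masked assignment,
# so a plain Python scalar fails:
foo2(5.0, mask)  # raises TypeError -- a float cannot be indexed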
What is the idiomatic way to handle this problem?