>>> a = torch.arange(12).reshape(2, 6)
>>> a
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
>>> b = a[1:, :]
>>> b.storage() is a.storage()
False

But

>>> b[0, 0] = 999
>>> b, a # both tensors are changed
(tensor([[999,   7,   8,   9,  10,  11]]),
 tensor([[  0,   1,   2,   3,   4,   5],
         [999,   7,   8,   9,  10,  11]]))

What exactly is the object that stores a tensor's data? How can I check whether two tensors share memory?

desertnaut

1 Answer

torch.Tensor.storage() returns a new instance of torch.Storage on every invocation. You can see this in the following:

a.storage() is a.storage()
# False

To compare the pointers to the underlying data, you can use the following:

a.storage().data_ptr() == b.storage().data_ptr()
# True

There is a discussion of how to determine whether PyTorch tensors share memory in this PyTorch forum post.
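For convenience, that check can be wrapped in a small helper. This is a minimal sketch (the name same_storage is mine, not a PyTorch API), and note that it only compares the base address of the underlying storage, so any two views of the same buffer count as sharing memory:

def same_storage(x, y):
    # Compare the base addresses of the underlying storages,
    # not the (view-dependent) addresses of the first elements.
    return x.storage().data_ptr() == y.storage().data_ptr()

same_storage(a, b)          # True  -- b is a view into a's storage
same_storage(a, a.clone())  # False -- clone() copies the data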


Note the difference between a.data_ptr() and a.storage().data_ptr(). The first returns the pointer to the first element of the tensor, whereas the second seems to point to the memory address of the underlying data (not of the sliced view), though this is not documented.
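As a quick illustration (reusing the a and b from the question above), the storage pointers of a slice and its base tensor match, while the element pointers do not:

a = torch.arange(12).reshape(2, 6)
b = a[1:, :]

a.storage().data_ptr() == b.storage().data_ptr()  # True  -- same underlying buffer
a.data_ptr() == b.data_ptr()                      # False -- b starts at a later element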

Knowing the above, we can understand why a.data_ptr() is different from b.data_ptr(). Consider the following code:

import torch

a = torch.arange(4, dtype=torch.int64)
b = a[1:]
b.data_ptr() - a.data_ptr()
# 8

The address of the first element of b is 8 bytes more than that of the first element of a, because we sliced off the first element and each element is 8 bytes (the dtype is a 64-bit integer).
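More generally, the offset between the two pointers can be written in terms of the slice's storage offset and the element size (a small sketch of the same relationship):

b.data_ptr() == a.data_ptr() + b.storage_offset() * a.element_size()
# True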

If we use the same code as above but with an 8-bit integer dtype, the addresses will differ by one byte.
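For example:

a = torch.arange(4, dtype=torch.int8)
b = a[1:]
b.data_ptr() - a.data_ptr()
# 1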

jkr