```python
import torch

a = torch.zeros(5)
b = torch.tensor([0, 1, 0, 1, 0], dtype=torch.uint8)  # 0/1 mask stored as uint8
c = torch.tensor([7., 9.])
print(a[b].size())
a[b] = c  # copy the values of c into a
print(a)
```

Output:

```
torch.Size([2])
tensor([0., 7., 0., 9., 0.])
```

I am struggling to understand how this works. I initially thought the above code was using fancy indexing, but I realised that the values from tensor c are copied to the positions marked 1 in b. Also, if I don't specify the dtype of b as uint8, the above code does not work. Can someone please explain the mechanism behind the above code?

Abhishek Kishore

1 Answer

Indexing with arrays works the same as in numpy and most other vectorized math packages I am aware of. There are two cases:

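  1. When b is of type uint8 (think boolean; when this was written, pytorch did not distinguish bool from uint8, while since PyTorch 1.2 torch.bool is the preferred mask type and uint8 masks are deprecated), a[b] is a 1-d array containing the subset of values of a (a[i]) for which the corresponding element of b (b[i]) is nonzero. Reading a[b] gives you a new tensor, not a view, so modifying that result does not change a. Assigning to a[b], as in a[b] = c, is what writes into a: the values of c are copied, in order, into the positions where b is nonzero, which is what happens in your code.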

  2. The alternative type you can use for indexing is an array of int64, in which case a[b] creates an array of shape (*b.shape, *a.shape[1:]). Its structure is as if each element of b (b[i]) was replaced by a[b[i]]. In other words, you create a new array by specifying from which indices of a the data should be fetched. Again, reading a[b] returns a copy, while an assignment a[b] = v writes v[i] into a[b[i]] for each i. This is also why your code breaks without dtype=torch.uint8: b then defaults to int64, so 0, 1, 0, 1, 0 are treated as indices, a[b] has shape [5], and the two values in c cannot be assigned to five positions. An example use case is shown in this question.
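The two cases can be sketched as follows, assuming PyTorch 1.2 or later (so the mask uses torch.bool instead of uint8). The names mask, selected, m, idx and v are illustrative. The sketch shows that reading through an index returns a copy while assigning through it writes into the original, and why the question's b fails once it is an int64 tensor.

```python
import torch

# Case 1: boolean mask (torch.bool, assuming PyTorch >= 1.2)
a = torch.zeros(5)
mask = torch.tensor([False, True, False, True, False])

selected = a[mask]       # gathers the True positions into a new 1-D tensor
print(selected.shape)    # torch.Size([2])
selected[0] = 100.0      # modifies the copy only
print(a)                 # tensor([0., 0., 0., 0., 0.])

a[mask] = torch.tensor([7.0, 9.0])  # assignment writes into a, in order
print(a)                 # tensor([0., 7., 0., 9., 0.])

# Case 2: int64 index tensor
m = torch.arange(10.0).reshape(5, 2)  # shape (5, 2)
idx = torch.tensor([[0, 4], [3, 3]])  # int64, shape (2, 2)
out = m[idx]                          # out[i, j] is the row m[idx[i, j]]
print(out.shape)                      # torch.Size([2, 2, 2]), i.e. (*idx.shape, *m.shape[1:])

# The question's b without dtype=torch.uint8 is int64, so 0/1 are indices
b = torch.tensor([0, 1, 0, 1, 0])
v = torch.zeros(5)
print(v[b].shape)                     # torch.Size([5]): v[0], v[1], v[0], v[1], v[0]

failed = False
try:
    v[b] = torch.tensor([7.0, 9.0])   # 2 values cannot broadcast to 5 slots
except (RuntimeError, IndexError) as err:
    failed = True
    print(type(err).__name__, err)
print(failed)
```

The exact exception type and message depend on the PyTorch version, which is why the sketch catches both RuntimeError and IndexError.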

These two modes are explained for numpy in integer array indexing and boolean array indexing, where for the latter you have to keep in mind that pytorch (before version 1.2) used uint8 in place of bool.
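For comparison, here is a minimal sketch of the same two modes in numpy. Note that numpy treats a uint8 array as ordinary integer indices rather than as a mask, which is exactly the difference from old pytorch mentioned above.

```python
import numpy as np

a = np.zeros(5)

# Boolean array indexing: the mask selects the True positions
mask = np.array([False, True, False, True, False])
a[mask] = [7.0, 9.0]
print(a)  # [0. 7. 0. 9. 0.]

# Integer array indexing: the same 0/1 pattern used as indices
idx = np.array([0, 1, 0, 1, 0])
print(a[idx])  # [0. 7. 0. 7. 0.]

# In numpy, uint8 is an ordinary integer dtype, so this is integer indexing too
u = np.array([0, 1, 0, 1, 0], dtype=np.uint8)
print(a[u])  # [0. 7. 0. 7. 0.]
```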

Also, if your goal is to copy data from one tensor to another, keep in mind that an operation like a[ixs] = b[ixs] is an in-place operation (a is modified in place), which may not play well with autograd. If you want to do out-of-place masking, use torch.where. An example use case is shown in this answer.
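A minimal sketch of out-of-place masking with torch.where, assuming PyTorch 1.2 or later for the bool mask. The tensors a and src are illustrative; both require gradients to show that autograd tracks each element back to whichever input supplied it. By contrast, an in-place write such as a[mask] = src[mask] on a leaf tensor that requires grad raises a RuntimeError.

```python
import torch

a = torch.zeros(5, requires_grad=True)
src = torch.tensor([1.0, 7.0, 2.0, 9.0, 3.0], requires_grad=True)
mask = torch.tensor([False, True, False, True, False])

# Take src where mask is True and a elsewhere; a and src are left unchanged
result = torch.where(mask, src, a)
print(result.detach())  # tensor([0., 7., 0., 9., 0.])

# Gradients flow to whichever input supplied each element
result.sum().backward()
print(src.grad)  # tensor([0., 1., 0., 1., 0.])
print(a.grad)    # tensor([1., 0., 1., 0., 1.])
```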

Jatentaki
  • Thanks for explaining – Abhishek Kishore Dec 17 '18 at 14:09
  • Can you suggest any documentation for the behaviour of arrays you explained above – Abhishek Kishore Dec 17 '18 at 14:16
  • Found this http://www.math.buffalo.edu/~badzioch/MTH337/PT/PT-boolean_numpy_arrays/PT-boolean_numpy_arrays.html – Abhishek Kishore Dec 17 '18 at 14:24
  • [This](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html). The two relevant paragraphs are [integer array indexing](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing) and [boolean array indexing](https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#boolean-array-indexing), where for the latter you have to keep in mind that pytorch uses `uint8` in place of `bool`. – Jatentaki Dec 17 '18 at 14:32
  • I expanded my answer to also mention `torch.where` which may be the actual solution if you want to copy data between tensors. – Jatentaki Dec 17 '18 at 14:39