
Suppose I have a 2D tensor x of shape (n,m). How can I extend the first dimension of the tensor by inserting zero rows into x, specifying the indices at which the zero rows will be located in the resulting tensor? For a concrete example:

x = torch.tensor([[1,1,1],
                  [2,2,2],
                  [3,3,3],
                  [4,4,4]])

Say I want to insert 2 zero rows so that their row indices in the resulting tensor are 1 and 3, respectively. In this example the result would be

X = torch.tensor([[1,1,1],
                  [0,0,0],
                  [2,2,2],
                  [0,0,0],
                  [3,3,3],
                  [4,4,4]])

I tried using F.pad and reshape.

ChrisNick92

2 Answers


You can use torch.cat:

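import torch

def insert_zeros(x, all_j):
    # Insert one zero row before each row index j of the ORIGINAL x.
    # all_j must be sorted in ascending order.
    zeros_ = torch.zeros_like(x[:1])      # a single zero row, same dtype/device as x
    pieces = []
    i = 0
    for j in list(all_j) + [len(x)]:
        pieces.extend([x[i:j], zeros_])   # rows i..j-1, then a zero row
        i = j
    # Drop the trailing zero row appended after the last chunk
    return torch.cat(pieces[:-1], dim=0)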

# insert_zeros(x, [1,2])
# tensor([[1, 1, 1],
#         [0, 0, 0],
#         [2, 2, 2],
#         [0, 0, 0],
#         [3, 3, 3],
#         [4, 4, 4]])
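Note that the indices passed to insert_zeros refer to rows of the original x (a zero row goes before row j), which is why [1, 2] puts the zero rows at positions 1 and 3 of the result. If you start from result positions as in the question, a small conversion should work: after sorting, subtract from each position the number of zero rows that precede it. A minimal sketch, assuming the positions are unique; the helper name result_positions_to_source is hypothetical:

```python
def result_positions_to_source(positions):
    """Convert zero-row positions in the RESULT to insert_zeros indices.

    Hypothetical helper. Assumes positions are unique. For sorted positions
    p_0 < p_1 < ..., the k-th zero row must go before original row p_k - k,
    because k zero rows have already been inserted ahead of it.
    """
    return [p - k for k, p in enumerate(sorted(positions))]

print(result_positions_to_source([1, 3]))
```

With that, insert_zeros(x, result_positions_to_source([1, 3])) should produce the tensor asked for in the question.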

This code supports backpropagation, since no tensor is modified in place: gradients flow through torch.cat back to x.
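One way to check that claim is to run insert_zeros on a float tensor that requires gradients and call backward(). Because every row of x appears exactly once in the output, the gradient of the sum with respect to x should be all ones. A minimal sketch; the float tensor is an assumption, since the integer tensor from the question cannot require gradients:

```python
import torch

def insert_zeros(x, all_j):
    # Same insert_zeros as in the answer: a zero row before each sorted index j of x
    zeros_ = torch.zeros_like(x[:1])
    pieces = []
    i = 0
    for j in list(all_j) + [len(x)]:
        pieces.extend([x[i:j], zeros_])
        i = j
    return torch.cat(pieces[:-1], dim=0)

# Float tensor so autograd can track it (integer tensors cannot require grad)
x = torch.tensor([[1., 1., 1.],
                  [2., 2., 2.],
                  [3., 3., 3.],
                  [4., 4., 4.]], requires_grad=True)

y = insert_zeros(x, [1, 2])
y.sum().backward()
# Each row of x appears once in y, so d(sum)/dx is a tensor of ones
print(x.grad)
```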


More information: What's the difference between torch.stack() and torch.cat()?

C-3PO
  • Thank you for your answer! Ideally I would like to make this more generic. Suppose the indices `[1,2]` are a given sequence with `n` elements; then this approach cannot be generalized easily. – ChrisNick92 Nov 13 '22 at 18:26
  • Hi, I edited my solution. Now it supports an arbitrary number `n` of insertions. Cheers, – C-3PO Nov 13 '22 at 18:33

You can use torch.Tensor.index_add_.

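import torch

zero_index = [1, 3]   # row positions of the zero rows in the RESULT

x = torch.tensor([[1,1,1],
                  [2,2,2],
                  [3,3,3],
                  [4,4,4]])

# The result has len(zero_index) extra rows; match x's dtype and device
size = (x.shape[0] + len(zero_index), x.shape[1])
t = torch.zeros(size, dtype=x.dtype, device=x.device)

# Destination row for each row of x: every position not reserved for zeros
index = torch.tensor([i for i in range(size[0]) if i not in zero_index],
                     device=x.device)
# index -> tensor([0, 2, 4, 5])

# Add x's rows into those positions (t is all zeros, so this acts as a copy)
t.index_add_(0, index, x)
print(t)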

Output:

tensor([[1, 1, 1],
        [0, 0, 0],
        [2, 2, 2],
        [0, 0, 0],
        [3, 3, 3],
        [4, 4, 4]])
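The list comprehension above loops in Python and does a list-membership check for every output row. A minimal vectorized sketch of the same idea builds the destination rows with a boolean mask instead, assuming zero_index holds unique, in-range positions in the result. It swaps index_add_ for masked assignment; since the target starts as all zeros, both should give the same result. The function name insert_zero_rows is hypothetical:

```python
import torch

def insert_zero_rows(x, zero_index):
    """Return a copy of x with zero rows at positions zero_index of the result.

    Hypothetical helper. Assumes zero_index contains unique positions in
    range(len(x) + len(zero_index)).
    """
    total = x.shape[0] + len(zero_index)
    # True for rows that receive data from x, False for the zero rows
    keep = torch.ones(total, dtype=torch.bool, device=x.device)
    keep[torch.as_tensor(zero_index, dtype=torch.long, device=x.device)] = False
    # Zero tensor with x's dtype/device and the extended first dimension
    out = x.new_zeros((total,) + tuple(x.shape[1:]))
    # Masked assignment fills the kept rows with x's rows, in order
    out[keep] = x
    return out

x = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3],
                  [4, 4, 4]])
print(insert_zero_rows(x, [1, 3]))
```

This variant takes result positions directly, like the index_add_ answer, and works for any number of zero rows.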
I'mahdi