
OpenAI's REINFORCE and actor-critic examples for reinforcement learning contain the following code:

REINFORCE:

policy_loss = torch.cat(policy_loss).sum()

actor-critic:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

One uses torch.cat, the other torch.stack, for what look like similar use cases.

As far as I understand, the docs don't draw any clear distinction between them.

I would be happy to know the differences between the functions.

– Gulzar
  • If you are interested in converting variable-length nested lists to tensors, here are links that seem to have a solution: https://stackoverflow.com/questions/55050717/converting-list-of-tensors-to-tensors-pytorch and https://discuss.pytorch.org/t/nested-list-of-variable-length-to-a-tensor/38699/21 – Charlie Parker Feb 09 '21 at 20:04

4 Answers


stack

Concatenates sequence of tensors along a new dimension.

cat

Concatenates the given sequence of seq tensors in the given dimension.

So if A and B are of shape (3, 4):

  • torch.cat([A, B], dim=0) will be of shape (6, 4)
  • torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
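For example, a quick check of those shapes (a minimal sketch):

import torch

A = torch.randn(3, 4)
B = torch.randn(3, 4)
print(torch.cat([A, B], dim=0).shape)    # torch.Size([6, 4])   -- existing dim 0 grows
print(torch.stack([A, B], dim=0).shape)  # torch.Size([2, 3, 4]) -- a new dim 0 is inserted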
– Jatentaki
  • Thus, torch.stack([A, B], dim=0) is equivalent to torch.cat([A.unsqueeze(0), B.unsqueeze(0)], dim=0). So if you find yourself doing many unsqueeze() operations before combining tensors, you can likely simplify your code using stack(). – DerekG Apr 06 '20 at 16:46
  • 3
    Just to complement, in the OpenAI examples in the question, `torch.stack` and `torch.cat` can be used interchangeably in either code line since `torch.stack(tensors).sum() == torch.cat(tensors).sum()`. – user118967 Aug 06 '20 at 01:31
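A quick check of that claim (a minimal sketch; losses is a hypothetical stand-in for the loss lists in the question):

import torch

losses = [torch.randn(1) for _ in range(5)]
print(torch.stack(losses).sum() == torch.cat(losses).sum())  # tensor(True)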
t1 = torch.tensor([[1, 2],
                   [3, 4]])

t2 = torch.tensor([[5, 6],
                   [7, 8]])
torch.stack: 'Stacks' a sequence of tensors along a new dimension.

torch.cat: 'Concatenates' a sequence of tensors along an existing dimension.

These functions are analogous to numpy.stack and numpy.concatenate.
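A minimal sketch of what the images illustrated, using t1 and t2 from above:

print(torch.stack((t1, t2)))
# tensor([[[1, 2],
#          [3, 4]],
#
#         [[5, 6],
#          [7, 8]]])   -- shape torch.Size([2, 2, 2])

print(torch.cat((t1, t2)))
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])     -- shape torch.Size([4, 2])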

– iacob

The original answer lacks a good self-contained example, so here is one:

import torch

# stack vs cat

# cat "extends" a list in the given dimension e.g. adds more rows or columns

x = torch.randn(2, 3)
print(f'{x.size()}')

# add more rows (dim 0 grows from 2 to 6)
xnew_from_cat = torch.cat((x, x, x), 0)
print(f'{xnew_from_cat.size()}')

# add more columns (dim 1 grows from 3 to 9)
xnew_from_cat = torch.cat((x, x, x), 1)
print(f'{xnew_from_cat.size()}')

print()

# stack serves the same role as append for lists: it doesn't change the
# original tensors but adds a new leading dimension, so you can still get
# each original tensor back by indexing into the new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 1)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 2)
print(f'{xnew_from_stack.size()}')

# by default, stack inserts the new dimension at the front (dim=0)
xnew_from_stack = torch.stack((x, x, x, x))
print(f'{xnew_from_stack.size()}')

print('I like to think of xnew_from_stack as a "tensor list" that you can pop from the front')

output:

torch.Size([2, 3])
torch.Size([6, 3])
torch.Size([2, 9])
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 3, 4])
torch.Size([4, 2, 3])
I like to think of xnew_from_stack as a "tensor list" that you can pop from the front

For reference, here are the definitions:

cat: Concatenates the given sequence of tensors in the given dimension. The consequence is that a specific dimension changes size, e.g. with dim=0 you are adding rows, so dimension 0 grows.

stack: Concatenates a sequence of tensors along a new dimension. I like to think of this as the torch "append" operation, since you can get an original tensor back by "popping" it from the front. With the default dim=0, the new dimension is inserted at the front.
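To make the "append/pop" intuition concrete, a small sketch reusing x from above:

stacked = torch.stack((x, x, x, x))  # torch.Size([4, 2, 3])
print(torch.equal(stacked[0], x))    # True -- indexing the new dimension "pops" an original tensor back out
catted = torch.cat((x, x, x, x))     # torch.Size([8, 3])
print(torch.equal(catted[0:2], x))   # True -- after cat, recovering x needs a slice, not a single index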


Update: with nested lists of the same size

def tensorify(lst):
    """
    Turn a nested list into a tensor. The list must be a nested list of
    tensors or numbers, with no varying lengths within a dimension.
    Nested list of depths [D1, D2, ..., DN] -> tensor of shape [D1, D2, ..., DN].

    :return: tensor of shape [D1, D2, ..., DN]
    """
    # base case: lst is already a tensor
    if isinstance(lst, torch.Tensor):
        return lst
    # base case: the current list is not nested anymore, make it into a tensor
    if not isinstance(lst[0], list):
        if isinstance(lst[0], torch.Tensor):
            return torch.stack(lst, dim=0)
        else:  # the elements of lst are floats, ints or similar
            return torch.tensor(lst)
    # recursive case: tensorify each sublist, then stack them along a new dimension
    current_dimension_i = len(lst)
    for d_i in range(current_dimension_i):
        lst[d_i] = tensorify(lst[d_i])
    # each lst[d_i] is now a tensor of shape [D2, ..., DN]
    tensor_lst = torch.stack(lst, dim=0)
    return tensor_lst

Here are a few unit tests (I didn't write more, but it worked with my real code so I trust it's fine; feel free to help by adding more tests if you want):


def test_tensorify():
    t = [1, 2, 3]
    print(tensorify(t).size())
    tt = [t, t, t]
    print(tensorify(tt))
    ttt = [tt, tt, tt]
    print(tensorify(ttt))

if __name__ == '__main__':
    test_tensorify()
    print('Done\a')
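And a minimal usage sketch on a nested list whose leaves are already tensors (a hypothetical example exercising the torch.stack branch):

nested = [[torch.randn(3) for _ in range(2)] for _ in range(4)]
print(tensorify(nested).size())  # torch.Size([4, 2, 3])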
– Charlie Parker

If someone is looking into the performance aspects of this, I've done a small experiment. In my case, I needed to convert a list of scalar tensors into a single tensor.

import torch
torch.__version__ # 1.10.2
x = [torch.randn(1) for _ in range(10000)]
torch.cat(x).shape, torch.stack(x).shape # torch.Size([10000]), torch.Size([10000, 1])

%timeit torch.cat(x) # 1.5 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.cat(x).reshape(-1,1) # 1.95 ms ± 534 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.stack(x) # 5.36 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

My conclusion is that even if you want the additional dimension that torch.stack provides, using torch.cat followed by a reshape is faster.
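To confirm the two approaches produce the same tensor, a minimal check with the same x as above:

print(torch.equal(torch.stack(x), torch.cat(x).reshape(-1, 1)))  # True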

Note: this post is taken from the PyTorch forum (I am the author of the original post)

  • That's interesting. Can you explain why that happens? My guess is cache or memory rows are in the `cat` direction and not in the `stack` direction. – Gulzar May 09 '22 at 09:34