
I'd like to compute a pairwise concatenation over a specific dimension in a batched manner.

For instance,

    x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
    x.shape  # torch.Size([2, 3, 1])

I would like to get y such that y is the concatenation of all pairs of vectors across one dimension, i.e.:

    y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]],
                     [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])

    y.shape  # torch.Size([2, 3, 3, 2])

So essentially, for each x[i,:], you generate all ordered pairs of vectors (x[i, j], x[i, k]) and concatenate each pair along the last dimension, so that y[i, j, k] is the concatenation of x[i, j] and x[i, k]. Is there a straightforward way of doing that?
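To pin down the expected semantics, here is a deliberately slow reference sketch with explicit loops. `pairwise_concat_loop` is a hypothetical helper name, and the sketch assumes x has shape (B, N, D); it is only meant as a ground truth for checking faster versions against.

```python
import torch


def pairwise_concat_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: x of shape (B, N, D) -> (B, N, N, 2 * D).

    out[b, i, j] is the concatenation of x[b, i] and x[b, j].
    """
    B, N, D = x.shape
    out = x.new_empty((B, N, N, 2 * D))  # same dtype/device as x
    for b in range(B):
        for i in range(N):
            for j in range(N):
                out[b, i, j] = torch.cat([x[b, i], x[b, j]])
    return out


x = torch.tensor([[[0], [1], [2]], [[3], [4], [5]]])
y = pairwise_concat_loop(x)
print(y.shape)      # torch.Size([2, 3, 3, 2])
print(y[1, 2, 0])   # tensor([5, 3])
```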

astiegler
Very similar to [this](https://discuss.pytorch.org/t/create-all-possible-combinations-of-a-3d-tensor-along-the-dimension-number-1/48155). – Quang Hoang Mar 05 '21 at 16:41

2 Answers


One possible way to do that would be:

    import torch

    all_ordered_idx_pairs = torch.cartesian_prod(torch.arange(x.shape[1]), torch.arange(x.shape[1]))  # shape (N * N, 2)
    y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])  # shape (B, N * N, 2, D)

After reshaping the tensor:

    y = y.view(x.shape[0], x.shape[1], x.shape[1], -1)

you get:

    y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]],
                     [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])
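For comparison, the same y can likely be built without an index tensor by broadcasting: repeat each vector along a new axis with `unsqueeze`/`expand`, then join the two expanded views with `torch.cat`. This is a swapped-in technique, not the method used in this answer, and `pairwise_concat_broadcast` is a hypothetical name; the sketch assumes x has shape (B, N, D).

```python
import torch


def pairwise_concat_broadcast(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x of shape (B, N, D) -> (B, N, N, 2 * D)."""
    B, N, D = x.shape
    left = x.unsqueeze(2).expand(B, N, N, D)   # left[b, i, j] == x[b, i]
    right = x.unsqueeze(1).expand(B, N, N, D)  # right[b, i, j] == x[b, j]
    # expand only creates views; torch.cat allocates the (B, N, N, 2 * D) result
    return torch.cat([left, right], dim=-1)


x = torch.tensor([[[0], [1], [2]], [[3], [4], [5]]])
y = pairwise_concat_broadcast(x)
print(y.shape)  # torch.Size([2, 3, 3, 2])
```

Because `expand` does not copy, the only large allocation is the output itself, and there is no Python loop over the batch.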
astiegler

A version without loops, using torch.arange(). The trick is to index the whole batch at once instead of using a for loop: the `:` in `x[:, idx_pairs]` keeps every element along the batch dimension, and the index pairs are applied along dimension 1 for each of them.

    import torch

    x = torch.tensor([
        [[0.0000, 1.0000, 2.0000],
         [3.0000, 4.0000, 5.0000],
         [0.0000, -1.0000, -2.0000],
         [-3.0000, -4.0000, -5.0000]],
        [[0.0000, 10.0000, 20.0000],
         [30.0000, 40.0000, 50.0000],
         [0.0000, -10.0000, -20.0000],
         [-30.0000, -40.0000, -50.0000]
         ]
    ])

    # all ordered index pairs (j, k) over dimension 1, shape (N * N, 2)
    idx_pairs = torch.cartesian_prod(torch.arange(x.shape[1]), torch.arange(x.shape[1]))
    # x[:, idx_pairs] has shape (B, N * N, 2, D); view turns it into (B, N, N, 2 * D)
    y = x[:, idx_pairs].view(x.shape[0], x.shape[1], x.shape[1], -1)
    print(y)

Output:

    tensor([[[[  0.,   1.,   2.,   0.,   1.,   2.],
              [  0.,   1.,   2.,   3.,   4.,   5.],
              [  0.,   1.,   2.,   0.,  -1.,  -2.],
              [  0.,   1.,   2.,  -3.,  -4.,  -5.]],
             [[  3.,   4.,   5.,   0.,   1.,   2.],
              [  3.,   4.,   5.,   3.,   4.,   5.],
              [  3.,   4.,   5.,   0.,  -1.,  -2.],
              [  3.,   4.,   5.,  -3.,  -4.,  -5.]],
             [[  0.,  -1.,  -2.,   0.,   1.,   2.],
              [  0.,  -1.,  -2.,   3.,   4.,   5.],
              [  0.,  -1.,  -2.,   0.,  -1.,  -2.],
              [  0.,  -1.,  -2.,  -3.,  -4.,  -5.]],
             [[ -3.,  -4.,  -5.,   0.,   1.,   2.],
              [ -3.,  -4.,  -5.,   3.,   4.,   5.],
              [ -3.,  -4.,  -5.,   0.,  -1.,  -2.],
              [ -3.,  -4.,  -5.,  -3.,  -4.,  -5.]]],
            [[[  0.,  10.,  20.,   0.,  10.,  20.],
              [  0.,  10.,  20.,  30.,  40.,  50.],
              [  0.,  10.,  20.,   0., -10., -20.],
              [  0.,  10.,  20., -30., -40., -50.]],
             [[ 30.,  40.,  50.,   0.,  10.,  20.],
              [ 30.,  40.,  50.,  30.,  40.,  50.],
              [ 30.,  40.,  50.,   0., -10., -20.],
              [ 30.,  40.,  50., -30., -40., -50.]],
             [[  0., -10., -20.,   0.,  10.,  20.],
              [  0., -10., -20.,  30.,  40.,  50.],
              [  0., -10., -20.,   0., -10., -20.],
              [  0., -10., -20., -30., -40., -50.]],
             [[-30., -40., -50.,   0.,  10.,  20.],
              [-30., -40., -50.,  30.,  40.,  50.],
              [-30., -40., -50.,   0., -10., -20.],
              [-30., -40., -50., -30., -40., -50.]]]])
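To see why the final reshape lines up, here is a sketch that prints the intermediate shapes for a small hypothetical input (`torch.arange(24.0).reshape(2, 4, 3)`, not the data above) and compares the result with an `expand` + `torch.cat` formulation. It uses `reshape` instead of `view` so it does not depend on the memory layout of the indexing result; that substitution is an assumption of this sketch, not part of the answer.

```python
import torch

x = torch.arange(24.0).reshape(2, 4, 3)  # hypothetical data, shape (B, N, D) = (2, 4, 3)
B, N, D = x.shape

idx_pairs = torch.cartesian_prod(torch.arange(N), torch.arange(N))  # (N * N, 2), rows (j, k)
gathered = x[:, idx_pairs]               # advanced indexing on dim 1 -> (B, N * N, 2, D)
y = gathered.reshape(B, N, N, -1)        # row j * N + k becomes [j, k]; 2 * D features

print(idx_pairs.shape)  # torch.Size([16, 2])
print(gathered.shape)   # torch.Size([2, 16, 2, 3])
print(y.shape)          # torch.Size([2, 4, 4, 6])

# Same result via broadcasting, without an index tensor
ref = torch.cat([x.unsqueeze(2).expand(B, N, N, D),
                 x.unsqueeze(1).expand(B, N, N, D)], dim=-1)
print(torch.equal(y, ref))  # True
```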
nick