I am trying to access multiple elements in a 3D PyTorch tensor, but the number of elements returned is wrong.

This is my code:

import torch

a = torch.FloatTensor(4,3,2)
print("a = {}".format(a))
print("a[:][:][0] = {}".format(a[:][:][0]))

This is the output:

a = tensor([[[-4.8569e+36,  3.0760e-41],
         [ 2.7953e+20,  1.6928e+22],
         [ 3.1692e-40,  7.2945e-15]],

        [[ 2.5011e+24,  1.3173e-39],
         [ 1.7229e-07,  4.1262e-08],
         [ 4.1490e-08,  6.4103e-10]],

        [[ 3.1728e-40,  5.8258e-40],
         [ 2.8776e+32,  6.7805e-10],
         [ 3.1764e-40,  5.4229e+08]],

        [[ 7.2424e-37,  1.3697e+07],
         [-2.0362e-33,  1.8146e+11],
         [ 3.1836e-40,  1.9670e+34]]])
a[:][:][0] = tensor([[-4.8569e+36,  3.0760e-41],
        [ 2.7953e+20,  1.6928e+22],
        [ 3.1692e-40,  7.2945e-15]])

I would expect something like this:

a[:][:][0] = tensor([[-4.8569e+36,  2.7953e+20, 3.1692e-40, 
          2.5011e+24, 1.7229e-07, 4.1490e-08, 
          3.1728e-40, 2.8776e+32, 3.1764e-40, 
          7.2424e-37, -2.0362e-33, 3.1836e-40]])

Can anyone explain to me how I can come to this result? Thank you very much in advance!

I get exactly the expected result on performing:

for i in range(4):
    for j in range(3):
        print("a[{}][{}][0] = {}".format(i, j, a[i][j][0]))
kmario23
Anno
  • Thank you, but this does not work. For whatever reason a[0][:][:], a[:][0][:], a[:][:][0] and even a[:][:][:][:][0] output the same?!?! How are these things implemented?!?! – Anno May 04 '19 at 12:17
  • I've added detailed explanation and annotation. That should clarify :) – kmario23 May 04 '19 at 23:54

2 Answers

Short answer: you need to use a[:, :, 0].

More explanation: each [...] is applied in turn to the result of the previous one, and a[:] returns a view of the whole tensor, which is equivalent to a itself. So a[:][:][0] is the same as a[0], which gives you the elements at position 0 of the first axis (hence the size (3, 2)). What you want are the elements at position 0 of the last axis, and for that you need a[:, :, 0].
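
A minimal sketch to check this, assuming PyTorch is installed. The tensor built with torch.arange is only an illustrative stand-in for the a in the question, chosen so the values are deterministic. The final reshape is one possible way to get the single-row layout sketched in the question:

```python
import torch

# Illustrative stand-in for the tensor in the question (deterministic values).
a = torch.arange(4 * 3 * 2).reshape(4, 3, 2)

# a[:] selects everything along the first axis, so chaining [:] changes nothing.
chained = a[:][:][0]
print(chained.shape)               # torch.Size([3, 2])
print(torch.equal(chained, a[0]))  # True

# One multi-axis index: all of axis 0, all of axis 1, position 0 on the last axis.
first_col = a[:, :, 0]
print(first_col.shape)             # torch.Size([4, 3])

# Flatten to a single row of 12 elements.
flat = first_col.reshape(1, -1)
print(flat.shape)                  # torch.Size([1, 12])
```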

Umang Gupta

Here is an explanation and the correct way to index the elements that you're looking for:

# input tensor to work with
In [11]: a = torch.arange(4*3*2).reshape(4,3,2)

# check its shape
In [12]: a.shape
Out[12]: torch.Size([4, 3, 2])

# inspect/annotate the tensor
In [13]: a
Out[13]:            # (    4      ,  3    ,    2      ) <= shape
tensor([[[ 0,  1],    | # block-0 | row-0 | col-0 col-1
         [ 2,  3],    | # block-0 | row-1 | col-0 col-1
         [ 4,  5]],   | # block-0 | row-2 | col-0 col-1

        [[ 6,  7],    | # block-1 | row-0 | col-0 col-1
         [ 8,  9],    | # block-1 | row-1 | col-0 col-1
         [10, 11]],   | # block-1 | row-2 | col-0 col-1

        [[12, 13],    | # block-2 | row-0 | col-0 col-1
         [14, 15],    | # block-2 | row-1 | col-0 col-1
         [16, 17]],   | # block-2 | row-2 | col-0 col-1

        [[18, 19],    | # block-3 | row-0 | col-0 col-1
         [20, 21],    | # block-3 | row-1 | col-0 col-1
         [22, 23]]])  | # block-3 | row-2 | col-0 col-1


# slice out what we need; (in all blocks, all rows, column-0)
In [14]: a[:, :, 0]
Out[14]: 
tensor([[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16],
        [18, 20, 22]])

Explanation/Clarification:

The tensor has shape [4, 3, 2], where 4 is the number of blocks (block-0, ... block-3), 3 is the number of rows in each block, and 2 is the number of columns in each row. We slice this using the notation a[:, :, 0], which takes column-0 of every row in every block.

To access a block, we need only one index (viz. a[0], ... a[3]). To access a specific row in a specific block, we need two indices (viz. a[0, 1], ... a[3, 2]). To access a specific column of a specific row in a specific block, we need three indices (viz. a[0, 1, 1] etc.).
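
As a quick check of the one-, two- and three-index cases, here is a small sketch on the same arange tensor used above (the variable names block, row and elem are just illustrative):

```python
import torch

a = torch.arange(4 * 3 * 2).reshape(4, 3, 2)

block = a[0]       # one index: the whole of block-0, shape (3, 2)
row = a[0, 1]      # two indices: row-1 of block-0, shape (2,)
elem = a[0, 1, 1]  # three indices: col-1 of row-1 of block-0 (0-dim tensor)

print(block.shape)  # torch.Size([3, 2])
print(row)          # tensor([2, 3])
print(elem.item())  # 3
```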


I surmise that part of the confusion in your case came from using torch.FloatTensor(). The problem with torch.FloatTensor(4, 3, 2) is that it returns uninitialized memory, so the tensor holds junk values or whatever a previous program left in those memory blocks. This can be puzzling to work with because the values may differ between runs.
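
A short sketch of the difference, assuming torch.empty as the uninitialized counterpart and torch.zeros / torch.arange as deterministic alternatives (the names junk, zeros and counted are illustrative only):

```python
import torch

# torch.empty allocates without initializing, so the contents are arbitrary.
junk = torch.empty(4, 3, 2)
print(junk.shape)  # torch.Size([4, 3, 2]); the values themselves may vary between runs

# Deterministic alternatives make indexing results much easier to read.
zeros = torch.zeros(4, 3, 2)
counted = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2)
print(counted[:, :, 0])
```

With known values like these, it is immediately visible which axis an index is selecting from.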

kmario23