
I need to combine 4 tensors of size [1,84,84], each representing a greyscale image, into a stack of shape [4,84,84]. The goal is for each of the four images to become one "channel" in PyTorch's CxHxW layout.

I am using PyTorch.

I've tried torch.stack and torch.cat. If one of these is the solution, I haven't figured out the correct preparation or method to get the result I want.

Thank you for any help.

from collections import deque

import torch
import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

    def get_stack(self):
        #t =  torch.cat(tuple(self.phi),dim=0)
        t =  torch.stack(tuple(self.phi),dim=0)
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
            self.phi.append(torch.tensor([1,nS[1], nS[2]]))

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)

The above code returns:

print(a.phi)

deque([tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84])], maxlen=4)

print(s, s.shape)

tensor([[ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84]]) torch.Size([4, 3])

What I would like is for the result to have shape [4, 84, 84]. I suspect this is quite simple, but it's escaping me.

1 Answer


It seems you have misunderstood what torch.tensor([1, 84, 84]) is doing. Let's take a look:

import torch

x = torch.tensor([1, 84, 84])
print(x, x.shape)  # tensor([ 1, 84, 84]) torch.Size([3])

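As the example shows, torch.tensor treats the list as data rather than as a shape. It returns a one-dimensional tensor holding the three values 1, 84 and 84, which is why stacking four of them gives torch.Size([4, 3]).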
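To make the distinction concrete, here is a minimal sketch. It contrasts torch.tensor, which builds a tensor from the values you pass, with torch.zeros, which takes a shape. The variable names are illustrative only.

```python
import torch

# torch.tensor interprets the list as *data*: three integers in a 1-D tensor.
data_tensor = torch.tensor([1, 84, 84])
print(data_tensor.shape)       # torch.Size([3])

# torch.zeros interprets its arguments as a *shape* and fills it with 0.0.
frame = torch.zeros(1, 84, 84)
print(frame.shape)             # torch.Size([1, 84, 84])

# The shape may also be given as a list, as in the fixed code.
frame_from_list = torch.zeros([1, 84, 84])
print(frame_from_list.shape)   # torch.Size([1, 84, 84])
```

Other shape-based constructors such as torch.ones, torch.empty and torch.rand behave the same way. Pick whichever initial values suit your buffer.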

From your problem statement, each frame needs to be a tensor of shape [1,84,84]. You can create one with a shape-based constructor such as torch.zeros. Here's what that looks like:

from collections import deque
import torch
import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

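    def get_stack(self):
        # Concatenate the four [1, 84, 84] frames along dim 0 -> [4, 84, 84].
        t = torch.cat(tuple(self.phi), dim=0)
        # t = torch.stack(tuple(self.phi), dim=0)  # would give [4, 1, 84, 84]
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
            # Original bug: torch.tensor([1, nS[1], nS[2]]) builds a 1-D tensor of three values.
            # torch.zeros takes a shape, so this appends a blank [1, 84, 84] frame.
            self.phi.append(torch.zeros([1, nS[1], nS[2]]))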

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)
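self.phi is a deque with maxlen=framestack, so appending a new frame evicts the oldest one. As a result, get_stack always sees exactly four frames. The sketch below illustrates this with a standalone deque rather than the full class. The frame values are dummy data, chosen so that the eviction is visible.

```python
from collections import deque

import torch

framestack = 4
phi = deque(maxlen=framestack)

# Pre-fill with blank [1, 84, 84] frames, as _initialize_stack does.
while len(phi) < phi.maxlen:
    phi.append(torch.zeros(1, 84, 84))

# Append five dummy frames filled with 1.0 ... 5.0; each append past
# maxlen silently drops the oldest frame from the left of the deque.
for value in range(1, 6):
    phi.append(torch.full((1, 84, 84), float(value)))

stacked = torch.cat(tuple(phi), dim=0)
print(stacked.shape)     # torch.Size([4, 84, 84])
print(stacked[:, 0, 0])  # tensor([2., 3., 4., 5.])
```

If the stack is meant for the GPU (the device='cuda' argument), you would still need to move it there, for example with stacked.to(device). The class as written does not do this.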

Note that torch.cat gives you a tensor of shape [4, 84, 84] because it joins the frames along their existing first dimension. torch.stack gives you a tensor of shape [4, 1, 84, 84] because it inserts a new dimension. For more on the difference, see What's the difference between torch.stack() and torch.cat() functions?
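The sketch below shows the difference, assuming four frames shaped [1, 84, 84] as in the question. It also shows a way to reach [4, 84, 84] with torch.stack if you prefer it: squeeze out each frame's singleton channel dimension first. That is equivalent to storing 2-D [84, 84] frames in the deque.

```python
import torch

frames = [torch.zeros(1, 84, 84) for _ in range(4)]

# cat joins along an existing dim (dim 0 here): 4 frames x 1 channel = 4 channels.
cat_result = torch.cat(frames, dim=0)
print(cat_result.shape)        # torch.Size([4, 84, 84])

# stack inserts a new dim 0 and keeps each frame's own channel dim.
stack_result = torch.stack(frames, dim=0)
print(stack_result.shape)      # torch.Size([4, 1, 84, 84])

# Same [4, 84, 84] result with stack: drop each frame's channel dim first.
squeezed = [f.squeeze(0) for f in frames]  # each is [84, 84]
stack_squeezed = torch.stack(squeezed, dim=0)
print(stack_squeezed.shape)    # torch.Size([4, 84, 84])
```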

keineahnung2345
  • Thank you so much! I knew it was going to be something fairly straightforward but I can't say I'd ever have managed to determine that bug with my current knowledge. I'll continue to study tensor creation/handling. Cheers. – White_Rabbit.obj Feb 12 '19 at 06:56
  • @White_Rabbit.obj If this answer helps you, please consider accepting it, thanks! – keineahnung2345 Feb 12 '19 at 07:05