I want to set up and preprocess a Gym environment with Stable Baselines3 and then write my own code for the agent, but I have trouble making sense of the Stable Baselines3 documentation. As a result, I don't fully understand what I'm doing and I run into shape errors.

I'm using the following to create the environments and stack them in order to process 4 of these simultaneously:

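from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
env = make_atari_env('AlienDeterministic-v4', n_envs=4)
env = VecFrameStack(env, n_stack=4)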

My understanding is that make_atari_env performs all the standard preprocessing (i.e. rescaling, greyscaling, clipping rewards/error, frame skipping, etc.) and that VecFrameStack stacks the environments. However, I suppose this does not transpose the image to channel-first (C, H, W) order? If not, what is the correct way of setting it up?
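To make the layout question concrete, here is a minimal NumPy-only sketch. It assumes (not verified against a running environment) that the vectorized env returns channel-last batches shaped (n_envs, 84, 84, n_stack), with VecFrameStack stacking the most recent frames along the last axis. The helper name to_channel_first is hypothetical. Stable Baselines3 also provides a VecTransposeImage wrapper in stable_baselines3.common.vec_env for this reordering; the manual transpose below only illustrates what such a reordering does.

```python
import numpy as np


def to_channel_first(batch: np.ndarray) -> np.ndarray:
    """Reorder a channel-last batch (N, H, W, C) to channel-first (N, C, H, W)."""
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {batch.shape}")
    return np.transpose(batch, (0, 3, 1, 2))


# Stand-in for one batch of observations (assumed layout):
# 4 environments, 84x84 pixels, 4 stacked frames in the last axis.
fake_obs = np.zeros((4, 84, 84, 4), dtype=np.uint8)
fake_obs[0, 10, 20, 3] = 255  # mark one pixel to check where it ends up

chw_obs = to_channel_first(fake_obs)
print(chw_obs.shape)
```

If the network stays fully connected (Flatten followed by Linear), the axis order does not change the number of input features; it only matters once convolutional layers are used.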

I'm still testing and have this network:

import random

import numpy as np
import torch
from torch import nn

class Network(nn.Module):
  def __init__(self, env):
    super().__init__()

    self.num_actions = env.action_space.n

    in_features = int(np.prod(env.observation_space.shape))

    self.net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features, 64),
        nn.Tanh(),
        nn.Linear(64, env.action_space.n) # equal to the number of possible actions
    )

  def forward(self, x):
    return self.net(x)

  def act(self, obses, epsilon):
    obses_t = torch.as_tensor(obses, dtype=torch.float32)
    q_values = self(obses_t)

    max_q_indeces = torch.argmax(q_values, dim=1)
    actions = max_q_indeces.detach().tolist()

    for i in range(len(actions)):
      rnd_sample = random.random()
      if rnd_sample <= epsilon:
        actions[i] = random.randint(0, self.num_actions - 1)

    return actions
 ................

Together with the rest of my training code, this produces the following error:

Step 0
Avg Rew 0.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-158-3bfd756ab672> in <module>
     35   rnd_sample = random.random()
     36 
---> 37   actions = online_net.act(obses, epsilon)
     38 
     39   new_obses, rews, dones, _ = env.step(actions)

6 frames
<ipython-input-157-fd78ebd7fe6f> in act(self, obses, epsilon)
     19   def act(self, obses, epsilon):
     20     obses_t = torch.as_tensor(obses, dtype=torch.float32)
---> 21     q_values = self(obses_t)
     22 
     23     max_q_indeces = torch.argmax(q_values, dim=1)

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-157-fd78ebd7fe6f> in forward(self, x)
     15 
     16   def forward(self, x):
---> 17     return self.net(x)
     18 
     19   def act(self, obses, epsilon):

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py in forward(self, input)
    202     def forward(self, input):
    203         for module in self:
--> 204             input = module(input)
    205         return input
    206 

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (84x336 and 28224x64)
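One possible reading of the numbers in this error, sketched with NumPy only. nn.Flatten() by default keeps axis 0 and merges all remaining axes. The sketch assumes the array reaching the network has shape (84, 84, 4), i.e. a single observation without the leading environment axis; the rest of the training loop is not shown, so this is a guess rather than a diagnosis. The helper name flatten_from_axis_1 is hypothetical.

```python
import numpy as np


def flatten_from_axis_1(x: np.ndarray) -> np.ndarray:
    """Mimic nn.Flatten() defaults: keep axis 0, merge every remaining axis."""
    return x.reshape(x.shape[0], -1)


obs_shape = (84, 84, 4)  # assumed per-environment observation shape
in_features = int(np.prod(obs_shape))
print(in_features)  # size the first Linear layer expects

# A full batch from 4 environments flattens to one row per environment.
batched = np.zeros((4, *obs_shape), dtype=np.float32)
batched_flat = flatten_from_axis_1(batched)
print(batched_flat.shape)

# A single observation without the batch axis: the image height is
# treated as the batch axis, leaving 84 rows of 84 * 4 values each.
single = np.zeros(obs_shape, dtype=np.float32)
single_flat = flatten_from_axis_1(single)
print(single_flat.shape)

# Adding a leading batch axis restores the expected feature count.
single_batched_flat = flatten_from_axis_1(single[None])
print(single_batched_flat.shape)
```

Under that assumption, the 84x336 in the message would be an unbatched observation, and 28224x64 the first Linear layer's weight matrix as PyTorch reports it. Printing obses.shape just before the call to act would confirm or rule this out.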

I don't understand how to use Stable Baselines3 to get the correct input shape. Any pointers or help would be appreciated. Thanks.

henrycmcjo