I want to setup and pre-process a gym environment with stable baselines 3 and then write my own code for the agent but I have trouble making sense of stable baselines documentation. This leads to not fully understanding what I'm doing and errors with shapes.
I'm using the following to create the environments and stack them in order to process 4 of these simultaneously:
env = make_atari_env('AlienDeterministic-v4', n_envs=4)
env = VecFrameStack(env, n_stack=4)
my understanding is that make_atari_env performs all standard processing (i.e. rescale, greyscale, clip rewards/error, frame skipping etc) and that VecFrameStack stacks the enviroments. However, I suppose this does not transpose the image to c h w? If not what is the correct way of setting it up?
I'm still testing and have this network:
class Network(nn.Module):
def __init__(self, env):
super().__init__()
self.num_actions = env.action_space.n
in_features = int(np.prod(env.observation_space.shape))
self.net = nn.Sequential(
nn.Flatten(),
nn.Linear(in_features, 64),
nn.Tanh(),
nn.Linear(64, env.action_space.n) # equal to the number of possible actions
)
def forward(self, x):
return self.net(x)
def act(self, obses, epsilon):
obses_t = torch.as_tensor(obses, dtype=torch.float32)
q_values = self(obses_t)
max_q_indeces = torch.argmax(q_values, dim=1)
actions = max_q_indeces.detach().tolist()
for i in range(len(actions)):
rnd_sample = random.random()
if rnd_sample <= epsilon:
actions[i] = random.randint(0, self.num_actions - 1)
return actions
................
along with other code for training and I get the following error:
Step 0
Avg Rew 0.0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-158-3bfd756ab672> in <module>
35 rnd_sample = random.random()
36
---> 37 actions = online_net.act(obses, epsilon)
38
39 new_obses, rews, dones, _ = env.step(actions)
6 frames
<ipython-input-157-fd78ebd7fe6f> in act(self, obses, epsilon)
19 def act(self, obses, epsilon):
20 obses_t = torch.as_tensor(obses, dtype=torch.float32)
---> 21 q_values = self(obses_t)
22
23 max_q_indeces = torch.argmax(q_values, dim=1)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-157-fd78ebd7fe6f> in forward(self, x)
15
16 def forward(self, x):
---> 17 return self.net(x)
18
19 def act(self, obses, epsilon):
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py in forward(self, input)
202 def forward(self, input):
203 for module in self:
--> 204 input = module(input)
205 return input
206
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py in forward(self, input)
112
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) -> str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (84x336 and 28224x64)
But I don't understand how to use stable baselines to get the correct shape (input). Any pointers or help is appreciated. Thanks