I am trying to evaluate the performance of a trained DQN model. The Deep Q-Network is defined as follows:
# Module-level default device: the first CUDA device if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeepQNetwork(nn.Module):
    """Three-layer fully connected network with one output per action.

    The network owns its optimizer (RMSprop), its loss (MSE), and the path
    of its checkpoint file.

    Args:
        lr: Learning rate passed to ``optim.RMSprop``.
        n_actions: Number of outputs of the final layer (``fc3``). This must
            equal the value used when the checkpoint was trained; otherwise
            ``load_checkpoint`` raises ``RuntimeError``.
        name: File name of the checkpoint inside ``chkpt_dir``.
        input_dims: Sequence whose first element is the input feature size.
        chkpt_dir: Directory the checkpoint is saved to and loaded from.
    """

    def __init__(self, lr, n_actions, name, input_dims, chkpt_dir):
        super().__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)
        # you may want to play around with this and forward()
        self.fc1 = nn.Linear(input_dims[0], 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, n_actions)
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    # you may want to play around with this
    def forward(self, state):
        """Return the raw ``fc3`` outputs (one per action) for ``state``.

        ``state``'s last dimension must equal ``input_dims[0]``.
        """
        flat1 = F.relu(self.fc1(state))
        flat2 = F.relu(self.fc2(flat1))
        actions = self.fc3(flat2)
        return actions

    def save_checkpoint(self):
        """Write this network's ``state_dict`` to ``self.checkpoint_file``."""
        print('... saving checkpoint ...')
        # T.save does not create missing directories; an empty dir means "cwd".
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load weights from ``self.checkpoint_file`` into this network.

        Raises:
            RuntimeError: if the checkpoint's number of actions differs from
                this network's ``n_actions``, or if ``load_state_dict``
                rejects the checkpoint for any other reason.
        """
        print('... loading checkpoint ...')
        # map_location remaps tensors saved on a CUDA device onto self.device,
        # so a GPU-trained checkpoint can be loaded on a CPU-only machine.
        state_dict = T.load(self.checkpoint_file, map_location=self.device)
        saved_bias = state_dict.get('fc3.bias')
        expected = self.fc3.out_features
        if saved_bias is not None and saved_bias.shape[0] != expected:
            raise RuntimeError(
                f"Checkpoint {self.checkpoint_file!r} was saved with "
                f"n_actions={saved_bias.shape[0]}, but this network was built "
                f"with n_actions={expected}. Construct the network with the "
                f"same n_actions that was used during training."
            )
        self.load_state_dict(state_dict)
The model is loaded in another file, which calls these functions to load it:
def use_bline(self):
    """Select the B-line DQN agent and restore its saved PyTorch weights.

    Sets ``self.agent`` to a ``DQNAgent`` pointing at the B-line checkpoint
    directory, loads its models, and sets ``self.agent_name`` to ``"Bline"``.
    """
    agent = DQNAgent(
        chkpt_dir="../Models_DQN/model_b_line",
        algo='DQNAgent',
        env_name='Scenario1b',
    )
    self.agent = agent
    # load_models() pulls the PyTorch checkpoint from chkpt_dir
    agent.load_models()
    self.agent_name = "Bline"
When I run the evaluation file, I get this error:
RuntimeError: Error(s) in loading state_dict for DeepQNetwork: size mismatch for fc3.weight: copying a param with shape torch.Size([54, 64]) from checkpoint, the shape in current model is torch.Size([41, 64]). size mismatch for fc3.bias: copying a param with shape torch.Size([54]) from checkpoint, the shape in current model is torch.Size([41]).
I saved the trained model and am trying to evaluate that same model on my laptop's CPU.
Let me know if you need more information. This may be a simple problem, but I have not been able to find a solution. Any help or direction would be very (VERY) appreciated.
I tried the approach described at this link, but got the same error.