
I'm using Curiosity exploration in RLlib, but I get the following error:

Unexpected key(s) in state_dict: "_curiosity_feature_net.post_fc_stack._value_branch._model.0.weight", "_curiosity_feature_net.post_fc_stack._value_branch._model.0.bias", "_curiosity_feature_net.logits_layer._model.0.weight", "_curiosity_feature_net.logits_layer._model.0.bias", "_curiosity_feature_net.value_layer._model.0.weight", "_curiosity_feature_net.value_layer._model.0.bias", "_curiosity_inverse_fcnet.0._model.0.weight", "_curiosity_inverse_fcnet.0._model.0.bias", "_curiosity_inverse_fcnet.1._model.0.weight", "_curiosity_inverse_fcnet.1._model.0.bias", "_curiosity_forward_fcnet.0._model.0.weight", "_curiosity_forward_fcnet.0._model.0.bias", "_curiosity_forward_fcnet.1._model.0.weight", "_curiosity_forward_fcnet.1._model.0.bias".
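To make the error message easier to read, the rejected keys can be grouped by their top-level submodule. This is only a minimal diagnostic sketch in plain Python: the key list is copied from the message above, and the helper name `group_by_top_module` is hypothetical, not part of RLlib or PyTorch.

```python
from collections import Counter

# Keys copied verbatim from the "Unexpected key(s) in state_dict" message.
UNEXPECTED_KEYS = [
    "_curiosity_feature_net.post_fc_stack._value_branch._model.0.weight",
    "_curiosity_feature_net.post_fc_stack._value_branch._model.0.bias",
    "_curiosity_feature_net.logits_layer._model.0.weight",
    "_curiosity_feature_net.logits_layer._model.0.bias",
    "_curiosity_feature_net.value_layer._model.0.weight",
    "_curiosity_feature_net.value_layer._model.0.bias",
    "_curiosity_inverse_fcnet.0._model.0.weight",
    "_curiosity_inverse_fcnet.0._model.0.bias",
    "_curiosity_inverse_fcnet.1._model.0.weight",
    "_curiosity_inverse_fcnet.1._model.0.bias",
    "_curiosity_forward_fcnet.0._model.0.weight",
    "_curiosity_forward_fcnet.0._model.0.bias",
    "_curiosity_forward_fcnet.1._model.0.weight",
    "_curiosity_forward_fcnet.1._model.0.bias",
]


def group_by_top_module(keys):
    """Count state_dict keys per top-level submodule (text before the first dot)."""
    return Counter(key.split(".", 1)[0] for key in keys)


counts = group_by_top_module(UNEXPECTED_KEYS)
for name, n in counts.items():
    print(f"{name}: {n} parameter tensors")
```

All 14 keys fall under three submodules whose names start with `_curiosity_`. Since "unexpected" keys are those present in the weights being loaded but absent from the module receiving them, it looks like the weights were saved from a model that had the Curiosity networks attached and are being loaded into one that does not.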

I configured it in Trainer.config["exploration_config"]:

    "exploration_config": {
        "type": "Curiosity",  # <- Use the Curiosity module for exploring.
        "eta": 1.0,  # Weight for intrinsic rewards before being added to extrinsic ones.
        "lr": 0.001,  # Learning rate of the curiosity (ICM) module.
        "feature_dim": 288,  # Dimensionality of the generated feature vectors.
        # Setup of the feature net (used to encode observations into feature (latent) vectors).
        "feature_net_config": {
            "fcnet_hiddens": [],
            "fcnet_activation": "relu",
        },
        "inverse_net_hiddens": [256],  # Hidden layers of the "inverse" model.
        "inverse_net_activation": "relu",  # Activation of the "inverse" model.
        "forward_net_hiddens": [256],  # Hidden layers of the "forward" model.
        "forward_net_activation": "relu",  # Activation of the "forward" model.
        "beta": 0.2,  # Weight for the "forward" loss (beta) over the "inverse" loss (1.0 - beta).
        # Specify, which exploration sub-type to use (usually, the algo's "default"
        # exploration, e.g. EpsilonGreedy for DQN, StochasticSampling for PG/SAC).
        "sub_exploration": {
            "type": "StochasticSampling",
        }
    },

This is where the error is raised:

File "/home/zhangzheng/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1494, in load_state_dict

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))

pikaz