I have trained an A2C model (MlpPolicy) using stable-baselines3. (I am quite new to reinforcement learning and found this to be a good place to start.) I now want to use an XRL (eXplainable Reinforcement Learning) method to understand the model better. I chose DeepSHAP because it has a nice implementation and I am already familiar with SHAP. DeepSHAP works with PyTorch, which is the framework underlying stable-baselines3, so my goal is to extract the underlying PyTorch model from the stable-baselines3 model. However, I am having some issues with this.
I found the following thread: https://github.com/hill-a/stable-baselines/issues/372 It helped a bit, but because the architecture of A2C differs from the model used in that thread, I have not yet been able to solve my problem.
From what I understand, stable-baselines3 offers the option to export a model's weights using
model.policy.state_dict()
However, I am struggling to import what I have exported through that method.
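For reference, the usual PyTorch round trip for a state dict looks roughly like the sketch below. It uses a small stand-in network instead of the real policy (an assumption, so that the snippet runs without stable-baselines3) and an in-memory buffer instead of a file. The key point is that a state dict holds only named tensors, so the receiving module must already have layers with matching names and shapes.

```python
import io

import torch as th
import torch.nn as nn


def build_net() -> nn.Sequential:
    # Stand-in network (hypothetical); in the real case this would be a
    # module whose layer names match those of A2C_model.policy.
    return nn.Sequential(nn.Linear(49, 64), nn.Tanh(), nn.Linear(64, 5))


policy = build_net()

# Export: a state dict is just an ordered mapping from names to tensors.
buffer = io.BytesIO()
th.save(policy.state_dict(), buffer)
buffer.seek(0)

# Import: build a module with the same structure, then load the weights.
restored = build_net()
restored.load_state_dict(th.load(buffer))

x = th.zeros(1, 49)
same_output = th.allclose(policy(x), restored(x))
print(same_output)  # True
```

If the names or shapes do not match, `load_state_dict` raises an error that lists the missing and unexpected keys, which is a quick way to see where a hand-built copy differs from the original.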
When printing out
A2C_model.policy
I get a glimpse of what the structure of the PyTorch model looks like. Output:
ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=5, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)
I tried to recreate it myself, but I am not fluent enough with PyTorch yet to get it to work.
So my question is: how can I export the stable_baselines3 model to PyTorch?
I have tried re-building the model architecture in PyTorch according to the output of printing A2C_model.policy. My code is currently:
import torch as th
import torch.nn as nn


class PyTorchMlp(nn.Module):
    def __init__(self):
        super().__init__()
        n_inputs = 49
        n_actions = 5
        # Flatten layers hold no weights; they only mirror the printed structure.
        self.features_extractor = nn.Flatten(start_dim=1, end_dim=-1)
        self.pi_features_extractor = nn.Flatten(start_dim=1, end_dim=-1)
        self.vf_features_extractor = nn.Flatten(start_dim=1, end_dim=-1)
        # Attributes cannot be assigned inside a call, so the two sub-networks
        # are grouped in a ModuleDict. This keeps the nested parameter names
        # mlp_extractor.policy_net.* and mlp_extractor.value_net.*.
        self.mlp_extractor = nn.ModuleDict({
            "policy_net": nn.Sequential(
                nn.Linear(in_features=n_inputs, out_features=64),
                nn.Tanh(),
                nn.Linear(in_features=64, out_features=64),
                nn.Tanh(),
            ),
            "value_net": nn.Sequential(
                nn.Linear(in_features=n_inputs, out_features=64),
                nn.Tanh(),
                nn.Linear(in_features=64, out_features=64),
                nn.Tanh(),
            ),
        })
        self.action_net = nn.Linear(in_features=64, out_features=n_actions)
        self.value_net = nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        # Not implemented yet: this is the part I am stuck on.
        pass
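One possible direction, sketched under assumptions rather than verified against a trained model: since DeepSHAP only needs a module that maps observations to outputs, it may be enough to rebuild the actor path (policy_net followed by action_net) and copy over the matching entries of the exported state dict. The names `ActorOnly` and `copy_actor_weights` are hypothetical, and `fake_state_dict` is a random stand-in that only mimics the key names and shapes implied by the printed structure; in practice it would be replaced by `A2C_model.policy.state_dict()`.

```python
import torch as th
import torch.nn as nn


class ActorOnly(nn.Module):
    """Hypothetical actor-only copy: observation -> action logits."""

    def __init__(self, n_inputs: int = 49, n_actions: int = 5):
        super().__init__()
        self.policy_net = nn.Sequential(
            nn.Linear(n_inputs, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
        )
        self.action_net = nn.Linear(64, n_actions)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = th.flatten(x, start_dim=1)  # same role as the Flatten extractor
        return self.action_net(self.policy_net(x))


def copy_actor_weights(policy_state_dict: dict, actor: ActorOnly) -> ActorOnly:
    """Keep only the actor entries and strip the 'mlp_extractor.' prefix."""
    prefix = "mlp_extractor."
    renamed = {}
    for key, value in policy_state_dict.items():
        if key.startswith(prefix + "policy_net."):
            renamed[key[len(prefix):]] = value
        elif key.startswith("action_net."):
            renamed[key] = value
    # strict=True fails loudly if any expected weight is missing.
    actor.load_state_dict(renamed, strict=True)
    return actor


# Random stand-in with the key names and shapes implied by the printout.
fake_state_dict = {
    "mlp_extractor.policy_net.0.weight": th.randn(64, 49),
    "mlp_extractor.policy_net.0.bias": th.randn(64),
    "mlp_extractor.policy_net.2.weight": th.randn(64, 64),
    "mlp_extractor.policy_net.2.bias": th.randn(64),
    "mlp_extractor.value_net.0.weight": th.randn(64, 49),
    "mlp_extractor.value_net.0.bias": th.randn(64),
    "mlp_extractor.value_net.2.weight": th.randn(64, 64),
    "mlp_extractor.value_net.2.bias": th.randn(64),
    "action_net.weight": th.randn(5, 64),
    "action_net.bias": th.randn(5),
    "value_net.weight": th.randn(1, 64),
    "value_net.bias": th.randn(1),
}

actor = copy_actor_weights(fake_state_dict, ActorOnly())
logits = actor(th.zeros(2, 49))
print(logits.shape)  # torch.Size([2, 5])
```

The resulting `actor` is a plain `nn.Module`, so it could presumably be passed to `shap.DeepExplainer` together with a tensor of background observations. Note that the outputs here are raw action logits, not probabilities; whether to explain logits or add a softmax on top is a separate design choice.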