Edit:
It seems there is an out-of-the-box way to do this by implementing a MaskSplitterNetwork:
Assuming the filter function takes the form:

def filter_fun(observation):
    return observation['observation'], observation['legal_moves']
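For clarity, here is a minimal sketch of what the splitter does: given the full observation dict, it returns the network input first and the mask second (the dict keys 'observation' and 'legal_moves' are taken from the answer above and depend on your environment):

```python
# Splitter as defined above: returns (network_input, mask), in that order.
def filter_fun(observation):
    return observation['observation'], observation['legal_moves']

# Hypothetical observation dict shaped like the environment's observation_spec.
sample = {
    'observation': [0.0, 1.0, 2.0],  # fed to the wrapped network
    'legal_moves': [1, 0, 1],        # 1 = action allowed, 0 = masked out
}
obs, mask = filter_fun(sample)
```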
Create the actor network (and, if required, the value network) and wrap it using the MaskSplitterNetwork constructor:
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import mask_splitter_network

masked_actor_network = mask_splitter_network.MaskSplitterNetwork(
    splitter_fn=filter_fun,
    wrapped_network=actor_distribution_network.ActorDistributionNetwork(
        train_env.observation_spec()['observation'],
        train_env.action_spec(),
        fc_layer_params=fc_layer_params,
    ),
    passthrough_mask=True,
)
Then feed the masked actor network into the REINFORCE agent:
from tf_agents.agents.reinforce import reinforce_agent

agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=masked_actor_network,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter,
)
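The effect of the mask is that illegal actions receive zero probability when the policy samples an action. A pure-Python sketch of that idea (not the actual TF-Agents implementation, which applies the mask inside the wrapped distribution network):

```python
import math

def masked_softmax(logits, mask):
    # Push logits of illegal actions (mask == 0) to -inf so they get
    # exactly zero probability after the softmax normalization.
    masked = [l if m else float('-inf') for l, m in zip(logits, mask)]
    mx = max(masked)
    exps = [math.exp(l - mx) if l != float('-inf') else 0.0 for l in masked]
    total = sum(exps)
    return [e / total for e in exps]

# Action 1 is illegal, so its probability is exactly 0.
probs = masked_softmax([2.0, 1.0, 0.5], [1, 0, 1])
```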