
When creating a DQN agent with TF-Agents, it's possible to specify a function that masks valid/invalid actions.

This is done by passing a function as the observation_and_action_constraint_splitter argument.

Apparently it's not possible to do the same for a REINFORCE agent.

How can I mask valid/invalid actions when using REINFORCE agents?

MarcoM

1 Answer


Edit:

It seems there is an out-of-the-box way to do this, by wrapping the network in a MaskSplitterNetwork:

Assuming the splitter function (here called filter_fun) has the form:

def filter_fun(observation):
    return observation['observation'], observation['legal_moves']
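
To make the expected shape concrete, here is a minimal sketch of how such a splitter behaves on a dict observation. The sample values (a 4-element observation and a 3-action mask) are made up for illustration, and plain Python lists stand in for the tensors a real environment would emit.

```python
def filter_fun(observation):
    # Split a dict observation into (network input, action mask).
    return observation['observation'], observation['legal_moves']

# Hypothetical sample: plain lists stand in for tensors.
sample = {
    'observation': [0.1, 0.0, -0.3, 0.7],
    'legal_moves': [1, 0, 1],  # 1 = legal action, 0 = illegal action
}

net_input, mask = filter_fun(sample)

# Indices of the actions the mask allows.
legal_actions = [i for i, ok in enumerate(mask) if ok]
print(legal_actions)  # [0, 2]
```

Only the first element is meant to reach the wrapped network, which is why the actor network in this answer is built from the 'observation' entry of the observation spec rather than from the whole dict.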

Create the actor network (and, if required, the value network) and wrap it using the MaskSplitterNetwork constructor:

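from tf_agents.networks import actor_distribution_network
from tf_agents.networks import mask_splitter_network

masked_actor_network = mask_splitter_network.MaskSplitterNetwork(
    splitter_fn=filter_fun,
    wrapped_network=actor_distribution_network.ActorDistributionNetwork(
        # The wrapped network only sees the 'observation' part of the dict.
        train_env.observation_spec()['observation'],
        train_env.action_spec(),
        fc_layer_params=fc_layer_params
    ),
    # Forward the mask to the wrapped network as well.
    passthrough_mask=True
)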

Then pass the masked actor network to the REINFORCE agent:

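from tf_agents.agents.reinforce import reinforce_agent

# optimizer and train_step_counter are assumed to be defined elsewhere.
agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=masked_actor_network,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter,
)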
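
For intuition about what a mask does to a policy, the sketch below applies a mask to action logits by hand and normalizes them with a softmax. This is a conceptual illustration in plain Python, not TF-Agents' internal code; masked_softmax and the sample numbers are hypothetical.

```python
import math

def masked_softmax(logits, mask):
    """Softmax over logits where actions with mask == 0 get probability 0."""
    if not any(mask):
        raise ValueError("at least one action must be legal")
    # Subtract the largest legal logit for numerical stability.
    top = max(l for l, ok in zip(logits, mask) if ok)
    exps = [math.exp(l - top) if ok else 0.0 for l, ok in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]

# The middle action has the highest logit but is masked out.
probs = masked_softmax([2.0, 5.0, 2.0], [1, 0, 1])
print(probs)  # [0.5, 0.0, 0.5]
```

Because the illegal action ends up with zero probability, a policy that samples from this distribution can never pick it, regardless of how large its logit is.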
jcaliz