0

This is linked to Cartesian product of nested dictionaries of lists

Suppose I have a nested dict with lists representing multiple configurations, like:

{'algorithm': ['PPO', 'A2C', 'DQN'],
'env_config': {'env': 'GymEnvWrapper-Atari',
'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

The goal is to compute the Cartesian product of the lists inside the nested dict, so as to obtain every possible configuration.

This is what I got so far:

def product(*args, repeat=1, root=False):
    """Yield the Cartesian product of the given iterables, with debug output.

    A hand-written stand-in for ``itertools.product`` that additionally
    prints every combination (framed by two lines of asterisks, with the
    ``root`` flag printed right after the first frame line) before yielding
    anything.

    Args:
        *args: The iterables to combine. Each one is materialised into a
            tuple up front.
        repeat: How many times the whole sequence of pools is repeated.
        root: Only echoed in the debug output; it does not affect the
            combinations that are produced.

    Yields:
        One tuple per combination, in the same order as they were printed.

    Because this is a generator, nothing is computed or printed until the
    first item is requested.
    """
    combos = [[]]
    # Grow every partial combination by each element of the next pool.
    for candidates in [tuple(arg) for arg in args] * repeat:
        combos = [prefix + [item] for prefix in combos for item in candidates]

    frame = "************************"
    print(frame)
    print(root)
    for combo in combos:
        print(tuple(combo))
    print(frame)

    yield from map(tuple, combos)


def recursive_cartesian_product(dic, root=True):
    """Yield one flat configuration per combination of the lists in ``dic``.

    Based on https://stackoverflow.com/a/50606871/11051330, extended to tell
    lists apart from plain entries so that strings in dicts of uneven depth
    are not split into characters.

    Args:
        dic: A (possibly nested) dict. A list value enumerates alternatives,
            a dict value is expanded recursively, and any other value is
            treated as a single fixed choice.
        root: Debug flag that is printed and forwarded to ``product``; the
            recursive calls pass ``False``.

    Yields:
        A dict with the same keys as ``dic`` in which every value has been
        replaced by one concrete choice.

    The function prints ``"!"`` followed by ``root`` once it starts running,
    and prints each combination right before yielding it.
    """
    def _choices(value):
        # Nested dicts expand lazily; anything that is not a list becomes
        # a one-element pool.
        if isinstance(value, dict):
            return recursive_cartesian_product(value, False)
        if isinstance(value, list):
            return value
        return (value,)

    keys = dic.keys()
    vals = (_choices(entry) for entry in dic.values())

    print("!", root)
    for conf in product(*vals, root=root):
        print(conf)
        yield dict(zip(keys, conf))

And here is the relevant output:

************************
True
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
************************
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})

Notice how the print statement inside product shows the correct combinations, while the print in the loop that yields the configurations does not: for the later configs the game value no longer varies (it is always 'pong').

jonrsharpe
  • 115,751
  • 26
  • 228
  • 437

3 Answers

0

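It turns out that the problem was not inside the above function, but outside of it. The resulting confs were passed to a function as **kwargs, which messed up the generator. Presumably that function modified the nested dicts; since the same nested dict objects are shared between several of the yielded configurations, the change then showed up in the configurations produced later.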

Here's a quick solution:

def recursive_cartesian_product(dic):
    """Yield every concrete configuration described by a nested dict of lists.

    Based on https://stackoverflow.com/a/50606871/11051330, extended to tell
    lists apart from plain entries so that strings are kept whole.

    Args:
        dic: A (possibly nested) dict. A list value enumerates alternatives,
            a dict value is expanded recursively, and any other value is
            treated as a single fixed choice.

    Yields:
        A dict with the same keys as ``dic`` in which every value has been
        replaced by one concrete choice. Each yielded dict is a deep copy,
        so a consumer that modifies one result cannot affect the results
        that follow.
    """
    def _alternatives(value):
        if isinstance(value, dict):
            return recursive_cartesian_product(value)
        if isinstance(value, list):
            return value
        return (value,)

    names = dic.keys()
    pools = map(_alternatives, dic.values())

    for combo in itertools.product(*pools):
        # The deep copy matters: without it the nested dicts would be
        # shared between several of the yielded configurations.
        yield deepcopy(dict(zip(names, combo)))
0

itertools has a product type already:

from itertools import product


# Nested configuration: every list enumerates the alternatives for its key.
d = {'algorithm': ['PPO', 'A2C', 'DQN'],
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

# Pair every algorithm with every game and rebuild the fixed part of the
# environment config around the chosen game.
for algo, game in product(d['algorithm'],
                          d['env_config']['env_config']['AtariEnv']['game']):
    print((algo, {'env': 'GymEnvWrapper-Atari',
                  'env_config': {'AtariEnv': {'game': game}}}))
chepner
  • 497,756
  • 71
  • 530
  • 681
  • Yes, I copied the reimplementation from itertools for debugging purposes. The problem was elsewhere, see my own answer below. – DisplayName Jun 21 '21 at 16:00
0

Using itertools.product is really simpler than rolling your own.

If you don't expect your env_config to change (except for game names), there is no need to implement a generic recursive dict visitor.
In that case you only need the product of the algorithms with the game names, always using the AtariEnv:

from itertools import product

# Every list in this nested mapping enumerates the alternatives for its key.
possible_configurations = {
    'algorithm': ['PPO', 'A2C', 'DQN'],
    'env_config': {
        'env': 'GymEnvWrapper-Atari',
        'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}},
    },
}

# First factor: the algorithm names.
algorithms = tuple(possible_configurations["algorithm"])

# Second factor: one complete environment config per game name.
games = tuple(
    {"env": "GymEnvWrapper-Atari", "env_config": {"AtariEnv": {"game": title}}}
    for title in possible_configurations["env_config"]["env_config"]["AtariEnv"]["game"]
)

# Print each (algorithm, environment config) pair.
factors = (algorithms, games)
for config in product(*factors):
    print(config)

If you prefer a general solution, here is mine:

from itertools import product

# Every list in this nested mapping enumerates the alternatives for its key.
possible_configurations = {
    'algorithm': ['PPO', 'A2C', 'DQN'],
    'env_config': {
        'env': 'GymEnvWrapper-Atari',
        'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}},
    },
}


def product_visitor(obj):
    """Yield every concrete value that the nested structure ``obj`` describes.

    Args:
        obj: A dict, a list, or any other value.
            * dict: each value is expanded recursively and one dict is
              yielded per combination of the expanded values.
            * list: treated as a set of alternatives; every element is
              expanded recursively and the results are yielded in order.
            * anything else: a scalar, yielded unchanged.

    Yields:
        The expanded alternatives. Dicts are freshly built for every
        yielded result, including nested ones, so modifying one result
        never leaks into another. Scalar leaves are passed through as-is.
    """
    if isinstance(obj, dict):
        # Expand each key's alternatives exactly once; the pools are reused
        # for every combination below.
        pools = [
            [(key, alternative) for alternative in product_visitor(value)]
            for key, value in obj.items()
        ]
        for combination in product(*pools):
            # The same alternative object appears in several combinations,
            # so hand out independent copies of any nested dicts.
            yield {key: _copy_dicts(alternative)
                   for key, alternative in combination}
    elif isinstance(obj, list):
        for value in obj:
            yield from product_visitor(value)
    else:  # either a string, a number, a boolean or null (all scalars)
        yield obj


def _copy_dicts(value):
    """Return ``value`` with every nested dict rebuilt; leaves are shared."""
    if isinstance(value, dict):
        return {key: _copy_dicts(inner) for key, inner in value.items()}
    return value


# Expand the whole configuration space up front so it can be printed and
# checked against the expected result.
configs = tuple(product_visitor(possible_configurations))
print(*configs, sep="\n")

# Expected order: algorithms vary slowest, games fastest.
assert configs == tuple(
    {
        'algorithm': algorithm,
        'env_config': {
            'env': 'GymEnvWrapper-Atari',
            'env_config': {'AtariEnv': {'game': game}},
        },
    }
    for algorithm in ('PPO', 'A2C', 'DQN')
    for game in ('breakout', 'pong')
)
Lenormju
  • 4,078
  • 2
  • 8
  • 22