My model contains several deep neural networks of different classes, and I want them all to share the same input size. For example, my model is:

class Model:
    def __init__(self, cfg: DictConfig):
        self.net1 = Net1(**cfg.net1_hparams)
        self.net2 = Net2(**cfg.net2_hparams)

Here, Net1 and Net2 have different sets of hyperparameters, but both take an input_size parameter whose value must match, i.e., cfg.net1_hparams.input_size == cfg.net2_hparams.input_size.

I could define input_size at the parent level (cfg.input_size) and pass it manually to both Net1 and Net2. However, I want each net's hparams config to be complete on its own, so that later I can build Net1 using only cfg.net1_hparams.

Is there a good way to achieve this in Hydra?

nzer0

1 Answer

This can be achieved using OmegaConf's variable interpolation feature.

Here is a minimal example using variable interpolation with Hydra to achieve the desired result:

# config.yaml
shared_hparams:
  input_size: [128, 128]
net1_hparams:
  name: net one
  input_size: ${shared_hparams.input_size}
net2_hparams:
  name: net two
  input_size: ${shared_hparams.input_size}

"""my_app.py"""
import hydra
from omegaconf import DictConfig

class Model:
    def __init__(self, cfg: DictConfig):
        print("Net1", dict(**cfg.net1_hparams))
        print("Net2", dict(**cfg.net2_hparams))

# config_path="." states explicitly that config.yaml sits next to my_app.py
@hydra.main(config_path=".", config_name="config")
def my_app(cfg: DictConfig) -> None:
    Model(cfg)

if __name__ == "__main__":
    my_app()

Running my_app.py at the command line produces this result:

$ python my_app.py
Net1 {'name': 'net one', 'input_size': [128, 128]}
Net2 {'name': 'net two', 'input_size': [128, 128]}
Jasha
  • Thank you for the detailed answer! Quick question: while the above code works well, when I print the cfg with `OmegaConf.to_yaml(cfg)`, I see the unresolved interpolation for input_size: `input_size: ${shared_hparams.input_size}`. Is this the expected behavior? – nzer0 May 18 '21 at 09:57
  • You're welcome! Yes, that is expected: by default, `OmegaConf.to_yaml` does not resolve interpolations. Try using the `resolve` keyword argument: `OmegaConf.to_yaml(cfg, resolve=True)`. Also of interest to you may be `OmegaConf.to_container(cfg, resolve=True)`, which will convert the `DictConfig` object to a plain Python `dict`. – Jasha May 18 '21 at 21:19