
In Hydra I can automatically instantiate my classes, e.g.

_target_: ClassA
foo: bar
model:
  _target_: ClassB
  hello: world

which results in recursively instantiated classes, where the "inner class" instances are passed to the "outer classes":

ClassA(
   foo="bar",
   model=ClassB(
      hello="world"
   )
)

Is there a way for ClassA to also receive the original configuration struct for ClassB, so that I have both the ClassB instance and the struct it was created from (e.g. a model and its hyperparameters)? In this example, that struct is:

model:
  _target_: ClassB
  hello: world
Nils Werner

1 Answer


You can achieve this by combining OmegaConf's variable interpolation with a custom resolver.

Here is an example where ClassA.__init__ receives both model and modelconf as arguments. The modelconf argument is a copy of cfg.model with its _target_ field removed.

# conf.yaml
_target_: app.ClassA
foo: bar
model:
  _target_: app.ClassB
  hello: world
modelconf: "${remove_target: ${.model}}"

# app.py
import hydra
from omegaconf import DictConfig, OmegaConf


class ClassB:
    def __init__(self, hello: str):
        print(f"ClassB.__init__ got {hello=}")


class ClassA:
    def __init__(self, foo: str, model: ClassB, modelconf: DictConfig) -> None:
        print(f"ClassA.__init__ got {foo=}, {model=}, {modelconf=}")


def remove_target_impl(conf: DictConfig) -> DictConfig:
    """Return a copy of `conf` with its `_target_` field removed."""
    conf = conf.copy()
    conf.pop("_target_")
    return conf


OmegaConf.register_new_resolver(
    "remove_target", resolver=remove_target_impl, replace=True
)


@hydra.main(config_path=".", config_name="conf.yaml")
def main(cfg: DictConfig) -> None:
    hydra.utils.instantiate(cfg)


if __name__ == "__main__":
    main()

At the command line:

$ python3 app.py
ClassB.__init__ got hello='world'
ClassA.__init__ got foo='bar', model=<app.ClassB object at 0x10a125ee0>, modelconf={'hello': 'world'}

The reason for removing the _target_ field from modelconf is that if the _target_ key were kept, the call to instantiate would turn modelconf into a second instance of ClassB rather than leaving it as a DictConfig.

See also the _convert_ parameter of instantiate. Using _convert_="partial" or _convert_="all" means that ClassA.__init__ receives modelconf as a plain dict rather than as a DictConfig. Alternatively, the resolver can add the _convert_ key to the copy itself:

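from omegaconf import DictConfig


def remove_target_and_add_convert_impl(conf: DictConfig) -> DictConfig:
    """Return a copy of `conf` without `_target_` and with `_convert_` set to "all"."""
    conf = conf.copy()
    conf.pop("_target_")
    conf["_convert_"] = "all"  # instantiate should now produce a dict rather than a DictConfig
    return conf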
Jasha