
I would like to use Python data types (both built-in ones and ones imported from libraries such as NumPy, TensorFlow, etc.) as arguments in my Hydra configuration. Something like:

# config.yaml

arg1: np.float32
arg2: tf.float16

I'm currently doing this instead:

# config.yaml

arg1: 'float32'
arg2: 'float16'

# my python code
import numpy as np
import tensorflow as tf
# ...
DTYPES_LOOKUP = {
  'float32': np.float32,
  'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]

Is there a more hydronic/elegant solution?

pppery
miccio

1 Answer


Does the hydra.utils.get_class function solve this problem for you?

# config.yaml

arg1: numpy.float32  # note: use "numpy" here, not "np"
arg2: tensorflow.float16

# python code
# ...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)

Update 1: using a custom resolver

Based on miccio's comment below, here is a demonstration using an OmegaConf custom resolver to wrap the get_class function.

from omegaconf import OmegaConf
from hydra.utils import get_class

OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))

config = OmegaConf.create("""
# config.yaml

arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")

arg1 = config.arg1
arg2 = config.arg2  # fails for tensorflow.float16; see Update 2

Update 2:

It turns out that get_class("numpy.float32") succeeds but get_class("tensorflow.float16") raises a ValueError. The reason is that get_class checks that the returned value is indeed a class (using isinstance(cls, type)).

The function hydra.utils.get_method is slightly more permissive, checking only that the returned value is a callable, but this still does not work with tf.float16.

>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False
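The two checks can be reproduced in plain Python on standard-library objects. This is a minimal illustration, not Hydra's own code, and `passes_checks` is a hypothetical helper. It shows which kinds of objects get through each check: a class passes both, a plain function passes only the callable check, and an instance passes neither, which is the situation of `tf.float16` (an instance of `tf.DType`, not a class or a function).

```python
import decimal


def passes_checks(obj):
    """Return (is a class, is callable) for obj.

    These mirror the two tests described above: get_class requires
    isinstance(obj, type), get_method only requires callable(obj).
    """
    return isinstance(obj, type), callable(obj)


if __name__ == "__main__":
    # A class passes both checks.
    print(passes_checks(decimal.Decimal))         # (True, True)
    # A plain function passes only the callable check.
    print(passes_checks(len))                     # (False, True)
    # An instance of a class, like tf.float16 (an instance of tf.DType),
    # passes neither check.
    print(passes_checks(decimal.Decimal("1.5")))  # (False, False)
```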

A custom resolver wrapping the tensorflow.as_dtype function might be in order.

>>> tf.as_dtype("float16")
tf.float16
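A sketch of such a resolver, under stated assumptions: the hypothetical `to_dtype` function below hands paths of the form `tensorflow.<name>` to `tf.as_dtype` and resolves every other dotted path by importing the module part with `importlib` and fetching the final attribute. Unlike `get_class`, it does not require the result to be a class. TensorFlow is imported only when a `tensorflow.` path is actually requested.

```python
import importlib
from typing import Any


def to_dtype(path: str) -> Any:
    """Hypothetical resolver body: turn a dotted path into a dtype-like object.

    Assumption: a path whose module part is exactly "tensorflow" is converted
    with tf.as_dtype (which accepts names such as "float16"); any other path
    is resolved by importing the module part and fetching the last attribute.
    """
    module_name, _, attr = path.strip().rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path like 'numpy.float32', got {path!r}")
    if module_name == "tensorflow":
        import tensorflow as tf  # imported lazily so TensorFlow stays optional
        return tf.as_dtype(attr)
    module = importlib.import_module(module_name)
    return getattr(module, attr)
```

It could then be registered the same way as `get_cls` above, e.g. `OmegaConf.register_new_resolver("dtype", to_dtype)` together with `arg2: "${dtype: tensorflow.float16}"` in the YAML. The prefix convention is an assumption of this sketch; a nested path such as `tensorflow.dtypes.float16` simply falls through to the attribute lookup, which also returns the dtype object.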
Jasha
  • Thanks for the high-quality answer as usual! In this case, since I don't want my DNN model to "depend" on hydra, I think I will embed the get_class function inside a custom resolver. – miccio Jan 19 '22 at 07:03
  • Great idea r.e. the custom resolver! – Jasha Jan 19 '22 at 09:00
  • Updated the answer with a caveat regarding `tf.float16`. I've opened a related [feature request](https://github.com/facebookresearch/hydra/issues/1975). – Jasha Jan 19 '22 at 09:37
  • impressive work, thanks again!! – miccio Jan 19 '22 at 11:06