I am trying to run the official Flax/ImageNet code. Because of differences between jax versions, I tried two setups: in the first I downgraded jax and jaxlib to 0.3.25, and in the second I changed jax and jaxlib to 0.4.4. Then I set up the TPU as follows:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
jax.local_devices()
which shows that the TPU devices are properly set up.
However, when I continue running the code, both versions give this error:
/usr/local/lib/python3.9/dist-packages/flax/__init__.py in <module>
20 )
21
---> 22 from . import core
23 from . import jax_utils
24 from . import linen
/usr/local/lib/python3.9/dist-packages/flax/core/__init__.py in <module>
14
15 from .axes_scan import broadcast as broadcast
---> 16 from .frozen_dict import (
17 FrozenDict as FrozenDict,
18 freeze as freeze,
/usr/local/lib/python3.9/dist-packages/flax/core/frozen_dict.py in <module>
48
49
---> 50 @jax.tree_util.register_pytree_with_keys_class
51 class FrozenDict(Mapping[K, V]):
52 """An immutable variant of the Python dict."""
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
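The error message suggests that the installed flax expects a function in jax.tree_util that the installed jax does not provide. Below is a minimal sketch for checking this without importing flax (which is the import that fails). The helper names installed_version and has_attribute are made up for illustration and are not part of jax or flax.

```python
import importlib
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        # Reads package metadata only; does not import the package itself.
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def has_attribute(module_name, attribute):
    """Return True if `module_name` can be imported and exposes `attribute`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attribute)


# Print the versions that are actually installed in this runtime.
for name in ("jax", "jaxlib", "flax"):
    print(name, installed_version(name))

# Check for the attribute named in the traceback.
print(has_attribute("jax.tree_util", "register_pytree_with_keys_class"))
```

If the last line prints False, the installed jax and flax versions presumably do not match each other, at least as far as this one attribute is concerned.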
I am really not sure where this comes from or how I can make this work.