I am trying to run the official Flax/ImageNet example code. Because of JAX version differences, I tried two approaches: in the first, I downgraded jax and jaxlib to 0.3.25, and in the second, I changed jax and jaxlib to 0.4.4. In both cases I then set up the TPU with:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
jax.local_devices()

and the output shows that the TPU devices are set up properly.
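A minimal sketch of checking that output programmatically, assuming only that each device returned by `jax.local_devices()` exposes a `platform` string such as "tpu" or "cpu". The helper names `platform_counts` and `is_tpu_only` are hypothetical, not part of JAX. If JAX silently fell back to the CPU, the platform would read "cpu" instead of "tpu".

```python
from collections import Counter
from typing import Dict, Iterable


def platform_counts(devices: Iterable) -> Dict[str, int]:
    """Count devices by their `platform` attribute ("tpu", "gpu", "cpu", ...)."""
    return dict(Counter(d.platform for d in devices))


def is_tpu_only(devices: Iterable) -> bool:
    """True if there is at least one device and every device reports platform "tpu"."""
    counts = platform_counts(devices)
    return bool(counts) and set(counts) == {"tpu"}
```

In the notebook this would be called as `is_tpu_only(jax.local_devices())` right after `setup_tpu()`.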

However, when I continue running the code, both versions give this error:

/usr/local/lib/python3.9/dist-packages/flax/__init__.py in <module>
     20 )
     21 
---> 22 from . import core
     23 from . import jax_utils
     24 from . import linen

/usr/local/lib/python3.9/dist-packages/flax/core/__init__.py in <module>
     14 
     15 from .axes_scan import broadcast as broadcast
---> 16 from .frozen_dict import (
     17   FrozenDict as FrozenDict,
     18   freeze as freeze,

/usr/local/lib/python3.9/dist-packages/flax/core/frozen_dict.py in <module>
     48 
     49 
---> 50 @jax.tree_util.register_pytree_with_keys_class
     51 class FrozenDict(Mapping[K, V]):
     52   """An immutable variant of the Python dict."""

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'

I am not sure where this error comes from or how to make this work.
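The traceback shows that the installed flax (in `flax/core/frozen_dict.py`) decorates `FrozenDict` with `jax.tree_util.register_pytree_with_keys_class`, so the installed jax does not provide an API that this flax release expects. A small sketch for spotting that kind of mismatch before importing flax; the helper `missing_attrs` is hypothetical and only uses the standard library.

```python
import types
from typing import Iterable, List


def missing_attrs(module: types.ModuleType, names: Iterable[str]) -> List[str]:
    """Return the names from `names` that `module` does not define, in order."""
    return [name for name in names if not hasattr(module, name)]
```

Usage, assuming jax is already installed: `import jax.tree_util` followed by `print(missing_attrs(jax.tree_util, ["register_pytree_with_keys_class"]))`. A non-empty list means the installed flax needs a newer jax than the one present.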

RanWang


Colab TPU is only compatible with JAX version 0.3.25 and older (see https://github.com/google/jax#pip-installation-colab-tpu), and flax versions more recent than 0.6.2 require a newer JAX version. If you want to use jax+flax on Colab TPU, you should install the following versions:

pip install jax==0.3.25 jaxlib==0.3.25 flax==0.6.2
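To confirm that the install actually produced these versions, a sketch using only the standard library's `importlib.metadata`. The `PINS` dict mirrors the command above; the helper name `mismatched_pins` is hypothetical. Note that this reads the installed package metadata, not the modules already imported into the running process, so if jax was imported before reinstalling, restarting the runtime is still needed.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Versions from the pip command above.
PINS = {"jax": "0.3.25", "jaxlib": "0.3.25", "flax": "0.6.2"}


def mismatched_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every pin that is not satisfied.

    `installed` is None when the package is not installed at all.
    """
    problems: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, expected in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            problems[name] = (expected, installed)
    return problems
```

Calling `print(mismatched_pins(PINS))` in the notebook should print an empty dict when all three pins match; any entry shows the expected and the installed version.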

It looks like the pip install step in the flax Colab TPU example needs to be updated; you might consider opening an issue at http://github.com/google/flax.

jakevdp