I'm running a TPU v3-8 VM on Google. On the VM, I installed jax with `pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`.

Unfortunately, when I call `jax.device_count()`, I get the message `No GPU/TPU found, falling back to CPU`. The same happens with `pip install jax==0.2.12`. It only works when I use `pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`, which installs the newest jax version. But I need jax version 0.2.12 or 0.2.16, because I would like to train GPT-J on a TPU following the tutorial https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md

How can I get it running with these versions?

BlackHawk
  • Hi @BlackHawk, can you try the commands mentioned in this [github link](https://github.com/google/jax/discussions/10323)? Let me know if this resolves your issue. – Shipra Sarkar Nov 20 '22 at 14:51
  • Thank you very much for your help. I first ran `pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` and then `pip install -U jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`, but when I then try to `import jax`, I get the error `Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib` – BlackHawk Nov 21 '22 at 17:50
  • By the way, I would like to run it on a TPU, not a GPU. – BlackHawk Nov 21 '22 at 18:03
  • Hi @BlackHawk, can you try setting `TF_CPP_MIN_LOG_LEVEL=0` while using jax version 0.2.12, or try training GPT-J with the latest version of jax? Let me know if these steps help. – Shipra Sarkar Nov 22 '22 at 14:27

1 Answer

Could you try explicitly setting `TPU_LIBRARY_PATH` to the current location of `libtpu.so`? It is most likely `/home/<your username>/.local/lib/python3.8/site-packages/libtpu/libtpu.so`.

Here is the relevant GitHub issue: https://github.com/google/jax/issues/13321

As mentioned there: "The underlying problem is that this version of jax still expected libtpu.so to be automatically installed in the VM image (https://github.com/google/jax/blob/jax-v0.2.16/jax/_src/cloud_tpu_init.py#L104), which the TPU VM base image no longer does."
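
As a rough sketch of the `TPU_LIBRARY_PATH` suggestion, the variable can also be set from Python before jax is imported. The helper names `find_libtpu` and `set_tpu_library_path` are made up for this example. The sketch assumes that the pip-installed `libtpu` package keeps `libtpu.so` in its package directory, as the path above suggests. I have not verified it against every jax version.

```python
import importlib.util
import os
from pathlib import Path


def find_libtpu():
    """Return the path of libtpu.so inside the pip-installed libtpu package, or None.

    Assumption: the shared library sits directly in the package directory.
    """
    spec = importlib.util.find_spec("libtpu")
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        candidate = Path(location) / "libtpu.so"
        if candidate.is_file():
            return str(candidate)
    return None


def set_tpu_library_path(path=None):
    """Set TPU_LIBRARY_PATH to `path`, or to the auto-detected location.

    Returns the path that was set, or None if nothing was found.
    """
    path = path or find_libtpu()
    if path is not None:
        os.environ["TPU_LIBRARY_PATH"] = path
    return path


# Run this before the first `import jax`, so the variable is already set
# when jax looks for the TPU library.
configured = set_tpu_library_path()
print("TPU_LIBRARY_PATH:", configured)
```

If the auto-detection prints `None`, pass the path explicitly, for example `set_tpu_library_path("/home/<your username>/.local/lib/python3.8/site-packages/libtpu/libtpu.so")`, and only then run `import jax` and `jax.device_count()`. Exporting the variable in the shell before starting Python should have the same effect.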