I have launched a Google Cloud TPU VM instance and installed the latest version of JAX, but JAX cannot see the TPU. Following the troubleshooting instructions at https://cloud.google.com/tpu/docs/troubleshooting/trouble-jax, I get the following:
>>> import jax
>>> jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
>>> TF_CPP_MIN_LOG_LEVEL=0
>>> jax.devices()
[CpuDevice(id=0)]
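Note that typing `TF_CPP_MIN_LOG_LEVEL=0` at the `>>>` prompt only binds an ordinary Python variable. It does not change JAX's logging, so the second `jax.devices()` call above gives no extra information. The variable has to be in the process environment before `jax` is first imported, either from the shell (`TF_CPP_MIN_LOG_LEVEL=0 python`) or through `os.environ`. Here is a minimal sketch, assuming a fresh interpreter in which `jax` has not been imported yet. On a working TPU VM, `jax.default_backend()` should report `"tpu"`. A result of `"cpu"` means JAX fell back, as in the transcript.

```python
import os

# Set the variable in the real process environment *before* importing jax.
# A bare `TF_CPP_MIN_LOG_LEVEL=0` at the Python prompt only creates a Python
# variable; setting it after jax is already imported is assumed to have no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax

# Ask JAX which platform it actually selected and which devices it sees.
backend = jax.default_backend()  # e.g. "tpu" on a working TPU VM, "cpu" on fallback
devices = jax.devices()

print("backend:", backend)
print("devices:", devices)
```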
Every search result I have found for this warning suggests installing JAX with CUDA support, but that shouldn't be necessary on a TPU, should it?
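CUDA is an NVIDIA GPU runtime and plays no part on a TPU. As far as I understand it, JAX talks to a TPU through a separate TPU runtime library that is pulled in by the `jax[tpu]` pip extra rather than by a plain `pip install jax`. The sketch below checks whether that runtime is importable at all. It assumes the runtime's Python module is named `libtpu`, which is an assumption and not something the troubleshooting page states. `importlib.util.find_spec` returns `None` for a missing top-level module instead of raising, so the check is safe to run either way.

```python
import importlib.util


def tpu_runtime_installed() -> bool:
    # Assumption: the TPU runtime is installed as a pip package whose
    # importable module is called "libtpu" (installed via the "jax[tpu]" extra).
    # find_spec returns None when a top-level module cannot be found.
    return importlib.util.find_spec("libtpu") is not None


has_libtpu = tpu_runtime_installed()
print("libtpu module found:", has_libtpu)
```

If this prints `False` on the TPU VM, a missing TPU runtime would be consistent with JAX falling back to `CpuDevice`.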