You can install the Python package JAX with extras that depend on your hardware.
For GPU:

```
pip install jax[cuda] --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

For TPU:

```
pip install jax[tpu] --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
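`--find-links` is an install-time option of pip, not part of the requirement string itself, so the two commands above can be generated from a small table of URLs. A minimal sketch, assuming you drive pip from a Python script; the names `JAX_FIND_LINKS` and `jax_install_command` are hypothetical:

```python
import shlex
import sys

# Hypothetical table: wheel pages from the install commands above, keyed by JAX extra.
JAX_FIND_LINKS = {
    "cuda": "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
    "tpu": "https://storage.googleapis.com/jax-releases/libtpu_releases.html",
}


def jax_install_command(extra: str) -> list[str]:
    """Build the argv for `pip install jax[<extra>] --find-links <url>`.

    Uses `python -m pip` so the pip of the current interpreter is invoked.
    """
    if extra not in JAX_FIND_LINKS:
        raise ValueError(f"no find-links URL known for extra {extra!r}")
    return [
        sys.executable, "-m", "pip", "install",
        f"jax[{extra}]",
        "--find-links", JAX_FIND_LINKS[extra],
    ]


if __name__ == "__main__":
    for extra in JAX_FIND_LINKS:
        # shlex.join quotes `jax[cuda]` so the printed line is safe to paste into a shell.
        print(shlex.join(jax_install_command(extra)))
```

Passing the list to `subprocess.run(jax_install_command("cuda"), check=True)` would then perform the GPU install with the current interpreter.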
How do I add those `--find-links` URLs to the following `pyproject.toml`?
```toml
[build-system]
requires = ["setuptools>=67.6.0"]
build-backend = "setuptools.build_meta"

[project]
name = "minimal_example"
version = '0.0.1'
requires-python = ">=3.9"
dependencies = [
    "seqio-nightly[gcp,cache-tasks]",
    "t5[gcp]",
    "t5x @ git+https://github.com/google-research/t5x.git"
]

[project.optional-dependencies]
cpu = ["jax[cpu]"]
gpu = ["jax[cuda]", "t5x[gpu] @ git+https://github.com/google-research/t5x.git"]
tpu = ["jax[tpu]", "t5x[tpu] @ git+https://github.com/google-research/t5x.git"]
dev = ["pytest", "mkdocs"]
```
Running `pip install -e .` gives me a working install, but `pip install -e ".[gpu]"` fails with a `ResolutionImpossible` error.
And `pip install -e ".[tpu]"` fails with:

```
Packages installed from PyPI cannot depend on packages which are not also hosted on PyPI.
jax depends on libtpu-nightly@ https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
```
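Entries in `[project.optional-dependencies]` are PEP 508 requirement strings, which have no field for an index or find-links page, so pip has to get those links from elsewhere: the command line, a requirements file, its config file, or the `PIP_FIND_LINKS` environment variable (several URLs separated by spaces). A minimal sketch of the environment-variable route, assuming the project root is the working directory; `FIND_LINKS`, `pip_env` and `install_extra` are hypothetical names, and whether this clears the errors above is not verified here:

```python
import os
import subprocess
import sys

# Hypothetical list: the same wheel pages as in the JAX install commands.
# Extra find-links pages only add candidate wheels, so both are passed for every extra.
FIND_LINKS = [
    "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
    "https://storage.googleapis.com/jax-releases/libtpu_releases.html",
]


def pip_env(find_links: list[str]) -> dict[str, str]:
    """Copy the current environment and set PIP_FIND_LINKS for pip to read."""
    env = dict(os.environ)
    # pip expects repeated options from environment variables as a space-separated list.
    env["PIP_FIND_LINKS"] = " ".join(find_links)
    return env


def install_extra(extra: str) -> None:
    """Run `pip install -e ".[<extra>]"` with the find-links pages in scope."""
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", f".[{extra}]"],
        env=pip_env(FIND_LINKS),
        check=True,  # raise CalledProcessError if pip exits non-zero
    )
```

For example, `install_extra("tpu")` runs the editable TPU install unchanged from the question, but with both wheel pages visible to pip.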