I followed the objax documentation to install the library with GPU support (https://objax.readthedocs.io/en/stable/installation_setup.html), i.e.:
pip install --upgrade objax
CUDA_VERSION=11.6
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
However, the last step fails with the following error message:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.15+cuda116 (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15)
ERROR: No matching distribution found for jaxlib==0.3.15+cuda116
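For reference, the pinned requirement in the error can be reproduced without running pip. This is a minimal sketch that hardcodes the jaxlib version reported in the error (0.3.15, an assumption standing in for the value the original command reads from the installed jaxlib) and shows how the command strips the dot from the CUDA version to build the `+cuda116` suffix:

```shell
#!/bin/sh
# Sketch: rebuild the version pin that the install command passes to pip.
# Assumption: jaxlib 0.3.15 is installed (the version shown in the error);
# the original command obtains it via: python3 -c 'import jaxlib; print(jaxlib.__version__)'
JAXLIB_VERSION=0.3.15
CUDA_VERSION=11.6

# Remove the dot from the CUDA version, e.g. 11.6 -> 116
CUDA_TAG=$(echo "$CUDA_VERSION" | sed 's:\.::g')

# Assemble the requirement exactly as the install command does
REQUIREMENT="jaxlib==${JAXLIB_VERSION}+cuda${CUDA_TAG}"
echo "$REQUIREMENT"   # prints jaxlib==0.3.15+cuda116
```

The resulting string matches the one pip reports, so the command itself expands as intended; the failure is that no wheel with that exact local version tag is found at the given index.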
I have tried multiple versions of Python and CUDA, but I always get this error.
Running pip install --upgrade pip at the beginning does not help either.
System description:
- Operating system: Ubuntu 20.04.4 LTS
- CUDA Version: 11.6
- Python version: 3.8.13