I have been trying to get the autograd/linear algebra package JAX to work on my old MacBook (CPU only, OS X 10.11.6), but I have been running into a series of problems along the way. I'm using conda 22.11.1 with Python 3.9.12.
Trying to install JAX with conda using conda install jax -c conda-forge does not work well. It spends a long time failing to solve the environment and eventually installs old, mismatched versions of jax and jaxlib. The incompatible versions of jax/jaxlib then raise an error in the Jupyter notebook.
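Because importing jax is what fails, one way to confirm which jax/jaxlib pair actually got installed is to read the package metadata without importing either package. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); the helper name installed_versions is hypothetical, and flagging any difference between the two versions is only a rough heuristic, since jax and jaxlib versions do not always have to be identical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version string, or None if not installed}.

    Only installed distribution metadata is read; the packages are never
    imported, so this cannot trigger an import-time crash.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    versions = installed_versions()
    print(versions)
    jax_v, jaxlib_v = versions["jax"], versions["jaxlib"]
    # Rough heuristic (assumption): flag any difference between the two versions.
    if jax_v and jaxlib_v and jax_v != jaxlib_v:
        print(f"jax {jax_v} and jaxlib {jaxlib_v} differ; check that they are compatible")
```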
After conda failed, I tried manually retrieving the latest versions using python -m pip install jax==0.3.25 jaxlib==0.3.25 (in the base conda env), but this returns the error No matching distribution found for jaxlib==0.3.25 (potentially OS related).
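pip's No matching distribution found error often means that none of the published wheels carries a platform tag compatible with the running system. For macOS wheels, the macosx_X_Y_arch part of the wheel filename names the oldest macOS release the wheel supports, so a wheel tagged for a release newer than 10.11 would be skipped. The sketch below is a simplified version of that check (it ignores universal2/intel multi-architecture tags); wheel_supports_macos is a hypothetical helper, and the example tags are illustrative, not the tags of any particular jaxlib release.

```python
import platform


def wheel_supports_macos(platform_tag, os_version, arch):
    """Simplified check: could a wheel with this macOS platform tag install here?

    platform_tag: e.g. "macosx_10_9_x86_64" (taken from a wheel filename)
    os_version:   e.g. "10.11.6" (as reported by platform.mac_ver())
    arch:         e.g. "x86_64"
    Assumption: only exact-architecture tags are handled (no universal2/intel).
    """
    prefix, major, minor, tag_arch = platform_tag.split("_", 3)
    if prefix != "macosx" or tag_arch != arch:
        return False
    # Pad "12" to "12.0" so single-component versions compare correctly.
    parts = (os_version.split(".") + ["0"])[:2]
    running = (int(parts[0]), int(parts[1]))
    # The tag names the *minimum* macOS version the wheel was built for.
    return running >= (int(major), int(minor))


if __name__ == "__main__":
    release, _, machine = platform.mac_ver()
    print("macOS release:", release or "(not macOS)", "machine:", machine)
    # Illustrative tags only, not taken from any specific jaxlib release.
    for tag in ("macosx_10_9_x86_64", "macosx_10_14_x86_64"):
        print(tag, wheel_supports_macos(tag, "10.11.6", "x86_64"))
```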
Finally, I uninstalled jax/jaxlib and then installed 0.3.10 for both. This approach got the packages to install, but now the Python kernel crashes with Segmentation fault 11 the moment jax or jax.numpy is imported.
Does anyone know what is going on, particularly with the Segmentation fault? Are these problems all related to my OS X version? Are there any workarounds I could use to get JAX up and running?
Edit: I am far from a conda expert, but here is my attempt to create a minimal environment that reproduces the problem. I ran conda create -n jax_env python=3.9 and conda activate jax_env, which installed the following packages:
bzip2 1.0.8 h0d85af4_4 conda-forge
ca-certificates 2022.12.7 h033912b_0 conda-forge
libffi 3.4.2 h0d85af4_5 conda-forge
libsqlite 3.40.0 ha978bb4_0 conda-forge
libzlib 1.2.13 hfd90126_4 conda-forge
ncurses 6.3 h96cf925_1 conda-forge
openssl 3.0.7 hfd90126_1 conda-forge
pip 22.3.1 pyhd8ed1ab_0 conda-forge
python 3.9.15 h709bd14_0_cpython conda-forge
readline 8.1.2 h3899abd_0 conda-forge
setuptools 65.5.1 pyhd8ed1ab_0 conda-forge
tk 8.6.12 h5dbffcc_0 conda-forge
tzdata 2022g h191b570_0 conda-forge
wheel 0.38.4 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h775f41a_0 conda-forge
Then, I ran python -m pip install jax==0.3.10 jaxlib==0.3.10, which added the following packages:
absl-py 1.3.0 pypi_0 pypi
flatbuffers 2.0.7 pypi_0 pypi
jax 0.3.10 pypi_0 pypi
jaxlib 0.3.10 pypi_0 pypi
numpy 1.23.5 pypi_0 pypi
opt-einsum 3.3.0 pypi_0 pypi
scipy 1.9.3 pypi_0 pypi
typing-extensions 4.4.0 pypi_0 pypi
Then, running python and attempting to import jax.numpy causes the same Segmentation fault 11.
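A segmentation fault kills the interpreter before any Python traceback is printed. The standard-library faulthandler module can dump the Python stack at the moment of the crash, which shows which import triggers it (python -X faulthandler enables the same thing from the command line). A minimal sketch, assuming the crash comes from one of the listed modules; the module order and the helper name import_one_by_one are hypothetical choices. Ordinary ImportErrors are caught and reported instead of stopping the script.

```python
import faulthandler
import importlib

# On SIGSEGV (and other fatal signals), dump the Python traceback of all
# threads to stderr before the process dies.
faulthandler.enable(all_threads=True)

# Hypothetical order: the compiled extension first, then the Python layers.
MODULES = ("jaxlib", "jaxlib.xla_extension", "jax", "jax.numpy")


def import_one_by_one(modules=MODULES):
    """Import modules in order; return a list of (name, error message or None).

    A segfault cannot be caught here; faulthandler's dump is the evidence then.
    """
    results = []
    for name in modules:
        print(f"importing {name} ...", flush=True)
        try:
            importlib.import_module(name)
        except ImportError as exc:
            results.append((name, str(exc)))
            break  # later modules depend on this one
        results.append((name, None))
    return results


if __name__ == "__main__":
    for name, error in import_one_by_one():
        print(name, "OK" if error is None else f"ImportError: {error}")
```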
Edit 2: I tried making a third environment, starting the same way but using conda install jax=0.3.10 -c conda-forge instead of pip. The environment solved faster this time, and the packages and their dependencies were installed. Interestingly, this method doesn't lead to a segmentation fault, but instead to the following error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 114, in <module>
    import jaxlib.xla_client as xla_client
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_client.py", line 25, in <module>
    from . import xla_extension as _xla
ImportError: dlopen(/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: _SecKeyCopyExternalRepresentation
  Referenced from: /Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so
  Expected in: /System/Library/Frameworks/Security.framework/Versions/A/Security
 in /Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so
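The ImportError says the dynamic loader could not find _SecKeyCopyExternalRepresentation in the system Security framework (the leading underscore is the prefix the macOS toolchain adds to C symbol names). One way to check that directly, independent of jaxlib, is to load the framework with ctypes and look the symbol up. A minimal sketch: has_symbol is a hypothetical helper, and the framework path is copied from the "Expected in:" line above. A result of False would be consistent with the installed jaxlib build expecting a newer macOS than this machine provides.

```python
import ctypes
import platform
import sys

# Path copied from the "Expected in:" line of the ImportError.
SECURITY = "/System/Library/Frameworks/Security.framework/Versions/A/Security"


def has_symbol(library_path, symbol):
    """Return True/False if `symbol` is exported, or None if the library can't load.

    `symbol` is the C name, without the leading underscore the macOS linker adds.
    """
    try:
        lib = ctypes.CDLL(library_path)
    except OSError:
        return None
    # ctypes raises AttributeError for names dlsym cannot resolve.
    return hasattr(lib, symbol)


if __name__ == "__main__":
    print("macOS release:", platform.mac_ver()[0] or "(not macOS)")
    if sys.platform == "darwin":
        print("SecKeyCopyExternalRepresentation available:",
              has_symbol(SECURITY, "SecKeyCopyExternalRepresentation"))
```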