
I have been trying to get JAX, the autograd/linear algebra package, to work on my old MacBook (CPU only, OS X 10.11.6), but I have run into a series of problems along the way. I'm using conda 22.11.1 with Python 3.9.12.

Installing JAX with conda via conda install jax -c conda-forge does not work well: conda spends a long time trying to solve the environment and finally installs old, mismatched versions of jax and jaxlib. These incompatible versions then raise an error in the Jupyter notebook.
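Since the conda route ended with mismatched jax and jaxlib, one way to see what actually got installed is to read the package metadata without importing jaxlib (the part that fails). The sketch below uses only the standard library's importlib.metadata. Its compatibility rule is an assumption: it only checks that the two version strings are identical, as in the pinned pairs tried in this question (0.3.10, 0.3.25). That is stricter than JAX's own check, which, as far as I know, compares jaxlib against a minimum required version at import time, so a mismatch here is a hint rather than proof.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution without importing it."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def jax_versions_match(jax_ver: Optional[str], jaxlib_ver: Optional[str]) -> bool:
    """Assumed rule for this sketch: both installed and identical release numbers."""
    if jax_ver is None or jaxlib_ver is None:
        return False
    return jax_ver.split(".")[:3] == jaxlib_ver.split(".")[:3]


if __name__ == "__main__":
    jax_ver = installed_version("jax")
    jaxlib_ver = installed_version("jaxlib")
    print(f"jax={jax_ver} jaxlib={jaxlib_ver} match={jax_versions_match(jax_ver, jaxlib_ver)}")
```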

After conda failed, I tried installing the latest versions manually with python -m pip install jax==0.3.25 jaxlib==0.3.25 (in the base conda env), but this fails with the error "No matching distribution found for jaxlib==0.3.25", possibly because of my OS version.
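pip only considers wheels whose platform tag is compatible with the running system, and a macOS tag such as macosx_10_14_x86_64 encodes the oldest macOS version the wheel supports. If every jaxlib 0.3.25 wheel requires a newer macOS than 10.11, pip reports "No matching distribution found" even though that version exists on PyPI. The sketch below parses the minimum version from a wheel filename. The filename in it is a hypothetical example, not necessarily the tag of the real jaxlib release files, and the regular expression ignores compound tags and the architecture part. On a real system, pip debug --verbose prints the full list of tags pip will accept.

```python
import platform
import re
from typing import Optional, Tuple

# Matches a single macOS platform tag at the end of a wheel filename,
# e.g. "...-macosx_10_14_x86_64.whl" -> ("10", "14", "x86_64").
_MACOS_TAG = re.compile(r"macosx_(\d+)_(\d+)_(\w+?)\.whl$")


def wheel_min_macos(filename: str) -> Optional[Tuple[int, int]]:
    """Return the minimum macOS (major, minor) encoded in a wheel filename, if any."""
    match = _MACOS_TAG.search(filename)
    if match is None:
        return None  # not a macOS-specific wheel (e.g. "py3-none-any")
    return int(match.group(1)), int(match.group(2))


def wheel_installable_on(filename: str, os_version: str) -> bool:
    """Rough check: does the running macOS satisfy the wheel's minimum version?"""
    required = wheel_min_macos(filename)
    if required is None:
        return True
    parts = [int(part) for part in os_version.split(".")[:2]]
    current = tuple(parts + [0] * (2 - len(parts)))  # "12" -> (12, 0)
    return current >= required


if __name__ == "__main__":
    # Hypothetical filename; the real jaxlib wheels may carry a different tag.
    wheel = "jaxlib-0.3.25-cp39-cp39-macosx_10_14_x86_64.whl"
    os_version = platform.mac_ver()[0]  # e.g. "10.11.6"; empty string off macOS
    if os_version:
        print(wheel, "installable:", wheel_installable_on(wheel, os_version))
```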

Finally, I uninstalled jax/jaxlib and installed version 0.3.10 of both. The packages installed, but now the Python kernel crashes with Segmentation fault 11 as soon as jax or jax.numpy is imported.
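A segmentation fault kills the whole interpreter (or Jupyter kernel), so it can help to probe the import in a child process instead. The sketch below runs the import with python -X faulthandler, which makes Python print a traceback when it receives a fatal signal such as SIGSEGV, and reports how the child ended. On POSIX systems a negative return code means the child was killed by that signal; "Segmentation fault 11" corresponds to signal 11, SIGSEGV. The function name probe_import is hypothetical.

```python
import signal
import subprocess
import sys


def probe_import(statement: str, timeout: float = 300.0) -> str:
    """Run `statement` in a fresh interpreter so a crash cannot kill this process.

    -X faulthandler makes the child dump a Python traceback to stderr on fatal
    signals such as SIGSEGV, showing which import was running when it crashed.
    """
    result = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", statement],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if result.returncode == 0:
        return "ok"
    if result.returncode < 0:
        # POSIX: a negative return code is the number of the signal that killed the child.
        reason = f"killed by {signal.Signals(-result.returncode).name}"
    else:
        reason = f"exited with code {result.returncode}"
    return f"{reason}\n{result.stderr.strip()}"


if __name__ == "__main__":
    # Hypothetical usage: run inside the jax_env environment from the question.
    print(probe_import("import jax.numpy"))
```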

Does anyone know what is going on, particularly with the Segmentation fault? Are these problems all related to my OS X version? Are there any workarounds I could use to get JAX up and running?

Edit: I am far from a conda expert, but here is my attempt at a minimal environment that reproduces the problem. I ran conda create -n jax_env python=3.9 followed by conda activate jax_env, which installed the following packages:

bzip2                 1.0.8                h0d85af4_4    conda-forge
ca-certificates       2022.12.7            h033912b_0    conda-forge
libffi                3.4.2                h0d85af4_5    conda-forge
libsqlite             3.40.0               ha978bb4_0    conda-forge
libzlib               1.2.13               hfd90126_4    conda-forge
ncurses               6.3                  h96cf925_1    conda-forge
openssl               3.0.7                hfd90126_1    conda-forge
pip                   22.3.1             pyhd8ed1ab_0    conda-forge
python                3.9.15       h709bd14_0_cpython    conda-forge
readline              8.1.2                h3899abd_0    conda-forge
setuptools            65.5.1             pyhd8ed1ab_0    conda-forge
tk                    8.6.12               h5dbffcc_0    conda-forge
tzdata                2022g                h191b570_0    conda-forge
wheel                 0.38.4             pyhd8ed1ab_0    conda-forge
xz                    5.2.6                h775f41a_0    conda-forge

Then I ran python -m pip install jax==0.3.10 jaxlib==0.3.10, which added the following packages:

absl-py                   1.3.0                    pypi_0    pypi
flatbuffers               2.0.7                    pypi_0    pypi
jax                       0.3.10                   pypi_0    pypi
jaxlib                    0.3.10                   pypi_0    pypi
numpy                     1.23.5                   pypi_0    pypi
opt-einsum                3.3.0                    pypi_0    pypi
scipy                     1.9.3                    pypi_0    pypi
typing-extensions         4.4.0                    pypi_0    pypi

Running python and then importing jax.numpy causes the same Segmentation fault 11.

Edit 2: I tried making a third environment, created the same way but using conda install jax=0.3.10 -c conda-forge instead of pip. The environment solved faster this time, and the packages and their dependencies were installed. Interestingly, this method does not lead to a segmentation fault, but to the following error instead:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/_src/config.py", line 27, in <module>
    from jax._src import lib
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 114, in <module>
    import jaxlib.xla_client as xla_client
  File "/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_client.py", line 25, in <module>
    from . import xla_extension as _xla
ImportError: dlopen(/Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so, 2): Symbol not found: _SecKeyCopyExternalRepresentation
  Referenced from: /Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so
  Expected in: /System/Library/Frameworks/Security.framework/Versions/A/Security
 in /Users/username/anaconda3/envs/jax_env/lib/python3.9/site-packages/jaxlib/xla_extension.so
cmm0052

1 Answer

Given the _SecKeyCopyExternalRepresentation error, I suspect the issue is that your OS X version is too old for the pre-built binaries available on PyPI; see Error importing tensorflow in anaconda on Mac OSX for a similar question.
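The error says the dynamic loader could not find _SecKeyCopyExternalRepresentation in the system Security framework. Apple's documentation lists SecKeyCopyExternalRepresentation as available starting with macOS 10.12, which would be consistent with a 10.11.6 system lacking it, although that version boundary is worth double-checking. The sketch below asks the loader directly, via ctypes, whether a library exports a given symbol; note that dlsym expects the C name without the leading underscore shown in the Mach-O error message. The helper exports_symbol is hypothetical.

```python
import ctypes
import ctypes.util


def exports_symbol(library_path: str, symbol: str) -> bool:
    """Return True if the library at library_path loads and defines `symbol`."""
    try:
        lib = ctypes.CDLL(library_path)
    except OSError:
        return False  # library missing or not loadable on this system
    # Attribute access on a CDLL performs a dlsym() lookup and raises
    # AttributeError when the symbol is not defined, so hasattr() works here.
    return hasattr(lib, symbol)


if __name__ == "__main__":
    # Sanity check that should hold on any POSIX system: libc defines printf.
    print("libc printf:", exports_symbol(ctypes.util.find_library("c"), "printf"))
    # Path taken from the error message; dlsym wants the name without the leading "_".
    security = "/System/Library/Frameworks/Security.framework/Versions/A/Security"
    print("SecKeyCopyExternalRepresentation:",
          exports_symbol(security, "SecKeyCopyExternalRepresentation"))
```

On an affected 10.11 machine the second line would be expected to print False; on 10.12 or later, True.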

I suspect your options will be either to upgrade your operating system, find a pre-built jaxlib for your specific architecture (conda-forge may be an option), or build jaxlib from source on your system.

jakevdp
I would expect that sourcing only from conda-forge (i.e., no longer using PyPI) should also work, since we still target all builds for macOS 10.9. – merv Dec 14 '22 at 18:20
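Following the suggestion to source only from conda-forge, one way to confirm that an environment no longer mixes in PyPI packages is to look at the channel column of conda list, where pip-installed packages show up as pypi (as in the question's second listing). The sketch below parses that plain-text output. It assumes the column layout shown in the question (name, version, build, channel) and skips lines beginning with #; this is an assumption about conda's output format rather than a documented contract.

```python
import subprocess
from typing import List


def pypi_packages(conda_list_output: str) -> List[str]:
    """Return names of packages whose channel column in `conda list` output is "pypi"."""
    names = []
    for line in conda_list_output.splitlines():
        fields = line.split()
        # Skip blank lines and the "#"-prefixed header lines.
        if not fields or fields[0].startswith("#"):
            continue
        # Assumed row layout: name, version, build, channel (channel may be absent).
        if len(fields) >= 4 and fields[3] == "pypi":
            names.append(fields[0])
    return names


if __name__ == "__main__":
    # Run inside the activated environment, e.g. jax_env from the question.
    output = subprocess.run(
        ["conda", "list"], capture_output=True, text=True, check=True
    ).stdout
    print("pip-installed packages:", pypi_packages(output))
```

An empty list suggests every package came from a conda channel.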