If you read the jax source code, you'll run into something called xla_client. It's often imported like this:
from . import xla_client
This implies that xla_client is a Python module, but I can't find any file with that name or any reference to a variable of that name.
I assume it's related to https://pypi.org/project/jaxlib/, but that package just links back to the jax source code.
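One way to narrow this down is to ask Python's import system which file a dotted name actually resolves to. Below is a minimal sketch using only the standard library. `locate_module` is a hypothetical helper name. The dotted name `jaxlib.xla_client` is only a guess based on the jaxlib PyPI link, and it resolves only if jaxlib is installed. Depending on how the package was built, the reported path could be a .py file or a compiled extension.

```python
import importlib.util
from typing import Optional


def locate_module(name: str) -> Optional[str]:
    """Return the file Python would load for `name`, or None if it can't be found.

    Hypothetical helper: find_spec imports the parent packages of a dotted
    name, so a missing parent (e.g. jaxlib not installed) raises
    ModuleNotFoundError, which is treated here as "not found".
    """
    try:
        spec = importlib.util.find_spec(name)
    except (ModuleNotFoundError, ValueError):
        # ValueError: module is in sys.modules but has no __spec__.
        return None
    if spec is None:
        return None
    # origin is the path of the source file or compiled extension that
    # backs the module (it can be None for namespace packages).
    return spec.origin
```

For example, `locate_module("jaxlib.xla_client")` returns a filesystem path if that module exists in the current environment, and None otherwise.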
Can anybody clue me in?