Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs/TPUs.
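A minimal sketch of how the transformations compose, assuming only the core `jax` package: `jax.grad` differentiates a plain Python function, `jax.vmap` vectorizes it over a batch axis, and `jax.jit` compiles the composed result. The function `loss` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar function of a weight w and an input x
    return jnp.sum((w * x) ** 2)

# d(loss)/dw for a single input
grad_loss = jax.grad(loss)

# Vectorize over a batch of inputs (axis 0 of x), sharing w across the batch
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# Compile the composed transformation
fast_batched_grad = jax.jit(batched_grad)

w = 3.0
xs = jnp.array([[1.0], [2.0]])
grads = fast_batched_grad(w, xs)  # one gradient per row of xs
print(grads)
```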

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
0 votes, 1 answer

Bazel build of JAX fails with missing dependency declarations

I am trying to build CUDA-enabled JAX from source on a cluster running CentOS 7. In the jax home directory, I run: python build/build.py --enable_cuda --cuda_path=$CUDA_HOME --cudnn_path=$CUDNN_HOME. Here are my specs: cuda version: 11.6 cudnn…
rice_cake
0 votes, 1 answer

jax minimization with stochastically estimated gradients

I'm trying to use the BFGS optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f which is estimated from pseudo-random samples and takes a jax.random.PRNGKey as an argument. To use this…
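One common pattern, sketched here as an assumption rather than as the asker's actual setup: hold the `PRNGKey` fixed and pass it through `args`, so that `jax.scipy.optimize.minimize` sees a deterministic function of `x` and differentiates with respect to `x` only. The objective `f` and helper `sample_z` are hypothetical: a Monte Carlo estimate of E[||x - z||^2] with z ~ N(1, I), whose minimizer for a fixed key is the sample mean of z.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

N_SAMPLES = 1000

def sample_z(key, dim):
    # Hypothetical sampler: the same key always yields the same z
    return 1.0 + jax.random.normal(key, (N_SAMPLES, dim))

def f(x, key):
    # Hypothetical Monte Carlo objective: mean of ||x - z||^2 over the samples
    z = sample_z(key, x.shape[0])
    return jnp.mean(jnp.sum((x - z) ** 2, axis=1))

key = jax.random.PRNGKey(0)
x0 = jnp.zeros(3)

# The key travels through `args`, so gradients are taken w.r.t. x only
result = minimize(f, x0, args=(key,), method="BFGS")
print(result.x, result.success)
```

Drawing a fresh key on every call would make f change between line-search evaluations, which is why the key is held fixed here.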
0 votes, 1 answer

Error when converting a JAX NumPy pre-trained weight file to an h5 weight file

I downloaded a JAX NumPy weight file with an .npz suffix, but when I tried to convert it to an .h5 file I received this error: import jax.numpy as jnp import h5py import tensorflow as tf BASE_URL =…
MediaJ
0 votes, 1 answer

How to create a physics-informed neural network (PINN) using jax

I am trying to create a physics-informed neural network (PINN) in JAX. I want to differentiate the defined model (neural network) with respect to the input (x). If I set model to jax.grad(params), I get an error. If I set model to jax.grad(model), I don't get an…
hohohohoho
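`jax.grad` transforms a function, not an array of parameters. A minimal sketch, assuming a hypothetical one-hidden-layer network `model(params, x)` that maps a scalar input to a scalar output: `argnums=1` differentiates with respect to the input `x` rather than `params`, nesting `jax.grad` gives the second derivative that PDE residuals often need, and `jax.vmap` evaluates both at a set of collocation points. `init_params` and the layer sizes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_params(key, hidden=16):
    # Hypothetical parameters for a 1 -> hidden -> 1 network
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (hidden,)),
        "b1": jnp.zeros(hidden),
        "W2": jax.random.normal(k2, (hidden,)) / hidden,
        "b2": jnp.array(0.0),
    }

def model(params, x):
    # Scalar input x -> scalar output u(x)
    h = jnp.tanh(params["W1"] * x + params["b1"])
    return jnp.dot(params["W2"], h) + params["b2"]

# du/dx and d2u/dx2: differentiate with respect to argument 1 (x), not params
du_dx = jax.grad(model, argnums=1)
d2u_dx2 = jax.grad(du_dx, argnums=1)

params = init_params(jax.random.PRNGKey(0))
xs = jnp.linspace(0.0, 1.0, 5)

# Evaluate at every collocation point; params are shared (in_axes=None)
u_x = jax.vmap(du_dx, in_axes=(None, 0))(params, xs)
u_xx = jax.vmap(d2u_dx2, in_axes=(None, 0))(params, xs)
```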
0 votes, 0 answers

Issue with importing SCVI package into a Jupyter environment [partially initialized module 'jax' has no attribute 'version']

I am analyzing single-cell RNA-seq data in Jupyter using Python 3. To delineate differential expression between cells/samples/etc., I need scanpy and SCVI. I am able to install SCVI and I installed scvi-tools…
Samuel S
0 votes, 1 answer

How to batch Jax array and vmap

I have a function that works on batches of an array, defined like this: def batched_fn(X): @jax.jit def apply(Xb): Xb_out = ... return Xb_out return apply The apply function will use X and Xb to calculate Xb_out and can be called on a…
bachr
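A minimal sketch, assuming the number of rows is divisible by the batch size: reshape `X` into `(n_batches, batch_size, d)` and let `jax.vmap` map the jitted `apply` over the leading axis. The body of `apply` is elided in the question, so the `Xb @ X.T` computation and the helper `run_in_batches` are hypothetical stand-ins that use both `X` and `Xb`.

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    @jax.jit
    def apply(Xb):
        # Hypothetical stand-in: similarity of each batch row to every row of X
        return Xb @ X.T
    return apply

def run_in_batches(X, batch_size):
    n, d = X.shape
    assert n % batch_size == 0, "sketch assumes n is divisible by batch_size"
    apply = batched_fn(X)
    # (n, d) -> (n_batches, batch_size, d)
    batches = X.reshape(n // batch_size, batch_size, d)
    # Map `apply` over the batch axis, then merge the batch axes back: (n, n)
    out = jax.vmap(apply)(batches)
    return out.reshape(n, -1)

X = jnp.arange(24.0).reshape(8, 3)
out = run_in_batches(X, batch_size=4)
```

Note that `vmap` evaluates all batches in one vectorized call; if the point of batching is to bound memory, `jax.lax.map(apply, batches)` applies `apply` to one batch at a time instead.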
0 votes, 1 answer

Getting incorrect output from the flax model's init call

I am trying to create a simple neural network using flax, as shown below. However, the params FrozenDict I receive as the output of model.init is empty instead of containing the parameters of the neural network. Also, type(predictions) is…
Bunny Rabbit
0 votes, 1 answer

Cannot compute simple gradient of lambda function in JAX

I'm trying to compute the gradient of a lambda function that involves other gradients of functions, but the computation is hanging and I do not understand why. In particular, the code below successfully computes f_next, but not its derivative…
Genoveffo
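Differentiating through a function that itself calls `jax.grad` is supported, because `jax.grad` composes with itself. A minimal sketch, assuming a hypothetical quadratic `f` and a one-step gradient-descent update written as a lambda `f_next`; it does not diagnose the hang in the question, it only shows the pattern working.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical objective
    return jnp.sum(x ** 2)

lr = 0.1
# One gradient-descent step, written as a lambda that calls jax.grad internally
f_next = lambda x: f(x - lr * jax.grad(f)(x))

x = jnp.array([1.0, -2.0])
value = f_next(x)
grad_value = jax.grad(f_next)(x)  # gradient of a function that contains a gradient
```

For this f, x - lr * 2x = 0.8x, so f_next(x) = 0.64 * sum(x^2) and its gradient is 1.28 * x.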
0 votes, 1 answer

Problem with jax tree_multimap when importing pybamm package

I am trying to import pybamm and it's throwing the following error at me: [screenshot of the error]. Following the suggestion of a previous post, I have tried !pip install "jax<=0.3.16" "jaxlib<=0.3.16", but the error remains the same.
Spencer
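For code you control, `jax.tree_util.tree_map` accepts additional pytrees with the same structure and maps over them in lockstep, which covers what `tree_multimap` was used for. A minimal sketch with hypothetical `params` and `grads` trees; it does not change the calls inside pybamm itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytrees with identical structure
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}

# Old style: jax.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```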
0 votes, 1 answer

TPU not found on Google VM (jax version 0.2.16)

I'm running a TPU v3-8 VM on Google. On the VM, I installed jax with pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html. Unfortunately, I'm getting the message No GPU/TPU found, falling back to CPU,…
BlackHawk
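A quick check of which backend JAX actually selected, using `jax.default_backend()`, `jax.devices()` and `jax.local_device_count()`; on a working TPU VM one would expect TPU devices to be listed rather than the CPU fallback. This only diagnoses the problem; it does not fix the installation.

```python
import jax

# Name of the platform JAX selected, e.g. "cpu" or "tpu"
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
print("local device count:", jax.local_device_count())
```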
0 votes, 1 answer

(Jax) Reshape pytree containing arrays of different shapes

I have a pytree containing arrays that have different shapes, for example it contains: observations of shape (5, 3, 250, 23) dones of shape (5, 3, 250) I want to reshape my pytree so that the first two dimensions are merged, which would give…
Valentin Macé
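Since every leaf shares the same two leading dimensions, `jax.tree_util.tree_map` can merge them leaf by leaf while leaving the trailing dimensions alone. A sketch with a hypothetical pytree matching the shapes in the question; `merge_leading_dims` is an illustrative helper.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree with the shapes from the question
tree = {
    "observations": jnp.zeros((5, 3, 250, 23)),
    "dones": jnp.zeros((5, 3, 250)),
}

def merge_leading_dims(x):
    # (a, b, *rest) -> (a * b, *rest)
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

merged = jax.tree_util.tree_map(merge_leading_dims, tree)
```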
0 votes, 0 answers

Flax Memory Consumption in Backwards Pass

Recently I built my first model in Flax. The forward pass worked fine, but I experienced OOM errors during the backward pass. Originally I had split my model into several small classes, each implemented as its own Flax model inheriting from…
Simon P.
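One standard way to trade compute for memory in the backward pass is rematerialization with `jax.checkpoint` (also available as `jax.remat`): intermediate activations inside the wrapped function are recomputed during the backward pass instead of being stored. Flax exposes the same idea for modules as `nn.remat`. Whether this resolves this particular OOM is an assumption; the sketch uses a hypothetical stack of dense blocks in plain JAX, with `block` and `make_loss` as illustrative names.

```python
import jax
import jax.numpy as jnp

def block(w, x):
    # Hypothetical layer: its intermediate activations are what remat discards
    return jnp.tanh(x @ w)

# Same math, but activations are recomputed in the backward pass
block_remat = jax.checkpoint(block)

def make_loss(layer):
    def loss(ws, x):
        for w in ws:
            x = layer(w, x)
        return jnp.sum(x ** 2)
    return loss

key = jax.random.PRNGKey(0)
ws = [0.1 * jax.random.normal(k, (64, 64)) for k in jax.random.split(key, 4)]
x = jnp.ones((8, 64))

g_plain = jax.grad(make_loss(block))(ws, x)
g_remat = jax.grad(make_loss(block_remat))(ws, x)
```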
0 votes, 0 answers

How do you install JAX on a Google Coral TPU?

The command I would like to use is pip install jax[tpu], but this keeps rolling back the version indefinitely (notice how, instead of just installing the latest version, it tries it and then downloads the next one, and so forth): Collecting jax[tpu] …
d-man
0 votes, 1 answer

Indexing a BCOO in Jax

I came across another problem in my attempts to learn jax: I have a sparse BCOO array, and an array holding indices. I need to obtain all values at these indices in the BCOO array. It would be ideal if the returned array would be a sparse BCOO as…
Simon P.
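A minimal sketch, assuming the query indices are given as an `(m, ndim)` integer array and that a dense vector of results is acceptable: compare each query against the stored coordinates in `M.indices` and sum the matching entries of `M.data`, so coordinates that are not stored read as zero. `bcoo_gather` is a hypothetical helper, not a built-in BCOO operation, and it uses O(nse * m) memory for the comparison.

```python
import jax.numpy as jnp
from jax.experimental import sparse

def bcoo_gather(M, query):
    # M.indices: (nse, ndim) stored coordinates; M.data: (nse,) stored values
    # query: (m, ndim) coordinates to look up
    match = jnp.all(M.indices[:, None, :] == query[None, :, :], axis=-1)  # (nse, m)
    # Sum matching stored values; coordinates that are not stored contribute zero
    return jnp.sum(jnp.where(match, M.data[:, None], 0), axis=0)

dense = jnp.array([[0.0, 1.5, 0.0],
                   [2.0, 0.0, 0.0],
                   [0.0, 0.0, 3.0]])
M = sparse.BCOO.fromdense(dense)
query = jnp.array([[0, 1], [1, 0], [2, 2], [0, 0]])
values = bcoo_gather(M, query)
```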
0 votes, 0 answers

Creating custom Haiku model parameters

I am trying to create custom parameters for a Haiku hk.Module. I know that hk.get_parameter() creates new parameters or returns existing ones; however, I also want to modify the values of those parameters before they are returned when calling…
Yusha Arif