Questions tagged [jax]

JAX lets you write auto-differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs/TPUs.
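As a minimal sketch of what "composable function transformations" means here, the snippet below chains `jax.grad`, `jax.vmap`, and `jax.jit` on a toy function. The function `f` and the input values are illustrative assumptions, not taken from any question on this page.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x

# grad differentiates f, vmap maps the derivative over a batch,
# and jit compiles the batched derivative.
df = jax.grad(f)
batched_df = jax.jit(jax.vmap(df))

xs = jnp.arange(3.0)     # [0., 1., 2.]
grads = batched_df(xs)   # derivative 2x + 3 evaluated at each point
```

Each transformation takes a function and returns a function, which is why they can be nested in any order that makes sense for the problem.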

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
0
votes
0 answers

Execute list of functions in parallel with JAX

I am performing a scientific computation in JAX that culminates in the sum result = sum([f(pytree) for f in fs]) where fs is a constant list of ~100 functions, all different and working on arrays of different sizes. Each function uses a subset of…
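One possible starting point, sketched under assumptions: wrap the whole sum in a single `jax.jit` so that the Python loop is unrolled at trace time and XLA receives one graph containing every function. The three functions and the pytree layout below are hypothetical stand-ins for the asker's `fs` and `pytree`; whether the compiled graph actually runs the pieces concurrently depends on the backend.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the ~100 different functions.
def f_a(tree):
    return jnp.sum(tree["a"] ** 2)

def f_b(tree):
    return jnp.sum(jnp.sin(tree["b"]))

def f_c(tree):
    return jnp.sum(tree["a"]) * jnp.mean(tree["b"])

fs = [f_a, f_b, f_c]

@jax.jit
def total(tree):
    # The loop over fs runs once, while tracing; the compiled
    # program contains all functions in a single graph.
    return sum(f(tree) for f in fs)

pytree = {"a": jnp.ones(4), "b": jnp.zeros(7)}
result = total(pytree)
```

With around 100 distinct functions, compile time can grow noticeably, so this trades a slower first call for fewer dispatches on later calls.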
0
votes
0 answers

Optimizing two different sets of weights and biases separately using Jax

I have provided my code below. I am trying to update two different sets of parameters of two Neural Networks separately. I am not able to figure out how to do that in jax optimizer. For the code given below the loss of objective function is not…
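A minimal sketch of one way to keep two parameter sets separate, assuming plain SGD rather than the asker's optimizer: pass each set as its own argument and request both gradients with `argnums=(0, 1)`. The two-layer model, the names `params_a` and `params_b`, and the learning rates are all hypothetical.

```python
import jax
import jax.numpy as jnp

def loss(params_a, params_b, x, y):
    # Hypothetical model: network A feeds network B.
    h = jnp.tanh(x @ params_a["w"] + params_a["b"])
    pred = h @ params_b["w"] + params_b["b"]
    return jnp.mean((pred - y) ** 2)

# One gradient pytree per parameter set.
grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))

def sgd(params, grads, lr):
    # Plain SGD applied leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
params_a = {"w": jax.random.normal(key_a, (3, 4)), "b": jnp.zeros(4)}
params_b = {"w": jax.random.normal(key_b, (4, 1)), "b": jnp.zeros(1)}
x = jnp.ones((8, 3))
y = jnp.zeros((8, 1))

loss_before = loss(params_a, params_b, x, y)
for _ in range(50):
    g_a, g_b = grad_fn(params_a, params_b, x, y)
    params_a = sgd(params_a, g_a, lr=1e-2)  # each set has its own update rule
    params_b = sgd(params_b, g_b, lr=5e-2)
loss_after = loss(params_a, params_b, x, y)
```

Because the gradients arrive as two separate pytrees, each set can use a different learning rate, a different optimizer state, or be frozen entirely.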
0
votes
0 answers

Loading pickle weights in JAX throws assertion error

I am trying to load weights from a saved pickle file. It throws an unspecific assertion error. I expected to be able to retrieve the opt_state from the saved file. I am not using flax or any other framework.
0
votes
1 answer

Best way to mipmap on jax

This is not really a question; rather, I was wondering if anyone has a better way of doing an occupancy grid in Jax (or in another language) for a 3D grid. Here is some working code. Does anyone have a better solution (or any problems with my…
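For illustration only, and not the asker's code (which is elided above): one common way to build a coarser occupancy level is to group the grid into 2x2x2 blocks with a reshape and mark a coarse cell as occupied if any child cell is. The function name `downsample_occupancy` and the cubic, even-sized grid are assumptions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def downsample_occupancy(grid):
    # grid: boolean array of shape (n, n, n) with n even (assumed).
    n = grid.shape[0]
    # Split every axis into (coarse index, offset within 2-cell block).
    blocks = grid.reshape(n // 2, 2, n // 2, 2, n // 2, 2)
    # A coarse cell is occupied if any of its 8 children is occupied.
    return blocks.any(axis=(1, 3, 5))

grid = jnp.zeros((4, 4, 4), dtype=bool).at[3, 0, 1].set(True)
coarse = downsample_occupancy(grid)  # shape (2, 2, 2)
```

Applying the function repeatedly yields successive mip levels until the grid reaches a single cell.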
0
votes
1 answer

Standard way to save and deploy JAX models?

I am learning JAX from a PyTorch background. I am used to saving serialized PyTorch models as .pt files then deploying them into another application for evaluation. What is the standard way of doing this with JAX? I looked at the Flax guide…
0
votes
1 answer

How to vectorize JAX functions using jit compilation and vmap auto-vectorization

How can I use jit and vmap in JAX to vectorize and speed up the following computation: @jit def distance(X, Y): """Compute distance between two matrices X and Y. Args: X (jax.numpy.ndarray): matrix of shape (n, m) Y…
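The asker's `distance` function is truncated above, so the following is only a sketch under an assumption: that the goal is the Euclidean distance between every row of `X` and every row of `Y`. A per-pair function is written first, then lifted with two nested `jax.vmap` calls and compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def point_distance(x, y):
    # Euclidean distance between two vectors of shape (m,).
    return jnp.sqrt(jnp.sum((x - y) ** 2))

# Inner vmap maps over rows of Y, outer vmap maps over rows of X.
pairwise = jax.jit(
    jax.vmap(jax.vmap(point_distance, in_axes=(None, 0)), in_axes=(0, None))
)

X = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 0.0], [0.0, 4.0], [3.0, 0.0]])
D = pairwise(X, Y)  # shape (2, 3); D[i, j] is the distance from X[i] to Y[j]
```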
0
votes
1 answer

ModuleNotFoundError: No module named 'jax.experimental.vectorize'

I had this issue when I was working on code that was not mine, and I was really overwhelmed by it for a month without finding any fix or help online. I have jax-0.4.10 installed and am using a MacBook Pro. I never found any solution online probably because…
Cyebukayire
  • 795
  • 7
  • 14
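As a hedged pointer rather than a confirmed fix for the code in question: current JAX releases expose a `vectorize` transformation as `jax.numpy.vectorize`. Assuming the old code only needs that function, a call might look like the sketch below; `_row_dot` and the array shapes are hypothetical.

```python
import jax.numpy as jnp

def _row_dot(a, b):
    # Core function written for 1-D inputs of equal length.
    return jnp.dot(a, b)

# The signature tells vectorize which axes are "core" axes;
# all leading axes are broadcast over.
row_dot = jnp.vectorize(_row_dot, signature="(n),(n)->()")

a = jnp.ones((5, 3))
b = jnp.arange(3.0)
out = row_dot(a, b)  # shape (5,): one dot product per row of a
```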
0
votes
0 answers

Solve a set of PDEs in Python

I am trying to use jax and python to solve df/dz = 0 for z1, z2, ..., zn. However, it seems my code is not working because all I get is zero (which is the initial guess I put in). I am writing this code as an exercise to get more familiar with Jax. I am…
Heng Yuan
  • 43
  • 4
0
votes
1 answer

CuDNN error when running JAX on GPU with apptainer

I have an application written in Python 3.10+ with JAX that I would like to run on GPU. I can run containers on my local computer cluster using apptainer (but not Docker) which has an NVIDIA A40 GPU. Based on the proposed Dockerfile for JAX I made…
Hylke
  • 75
  • 6
0
votes
0 answers

JAX scan efficient in backprop without TensorArray?

TF provides the TensorArray to make automatic iteration and stacking efficient in scan or while_loop. The naive variant with gathering and concatenating or dynamic updates would be inefficient with backprop, because backprop would keep copies of the…
Albert
  • 65,406
  • 61
  • 242
  • 386
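For context, a small sketch of how `jax.lax.scan` handles iteration and stacking: the step function returns a per-step output, and `scan` stacks those outputs itself, so no manual concatenation or dynamic update is written by the user. The running-sum example is an assumption chosen for brevity; it does not measure memory use in backprop.

```python
import jax
import jax.numpy as jnp

def running_sum(xs):
    def step(carry, x):
        carry = carry + x
        # Return (new carry, output for this step).
        return carry, carry
    _, ys = jax.lax.scan(step, jnp.zeros(()), xs)
    return ys  # per-step outputs, stacked along axis 0

xs = jnp.array([1.0, 2.0, 3.0])
ys = running_sum(xs)

# scan is differentiable like any other JAX primitive.
g = jax.grad(lambda v: running_sum(v).sum())(xs)
```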
0
votes
0 answers

Data download issue with official Flax Image Net example

I am still trying to understand this official Flax Example. For the convenience of the experiment, I have created my own copy. In the section on running locally, it seems that there is no download command. Therefore when I run python main.py…
RanWang
  • 310
  • 2
  • 12
0
votes
1 answer

How to unroll the training loop so that Jax can train multiple steps in GPU/TPU

When using powerful hardware, especially TPU, it is often preferable to train multiple steps. For example, in TensorFlow, this is possible. with strategy.scope(): model = create_model() optimizer_inner = AdamW(weight_decay=1e-6) …
RanWang
  • 310
  • 2
  • 12
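One way this is often approached in JAX, sketched here with a hypothetical linear model and plain SGD instead of the AdamW setup from the question: stack `k` batches along a leading axis and run the update inside `jax.lax.scan`, so a single jitted call performs `k` training steps on the device.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_k_steps(w, batches, lr=0.1):
    def step(w, batch):
        loss, grad = jax.value_and_grad(loss_fn)(w, batch)
        return w - lr * grad, loss  # (updated weights, loss for this step)
    # scan slices one batch per step from the leading axis of batches.
    return jax.lax.scan(step, w, batches)

k = 5
xs = jnp.ones((k, 8, 2))      # k batches of 8 examples with 2 features
ys = jnp.full((k, 8), 2.0)    # targets for each batch
w, losses = train_k_steps(jnp.zeros(2), (xs, ys))
```

The host only regains control after all `k` steps, which reduces dispatch overhead at the cost of less frequent logging.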
0
votes
0 answers

Struve function implementation in Jax gives error for shape method when a vector is supposed to be passed

This implementation of the Struve function is correct when checked against the scipy.special implementation, but since it is not implemented in jax.scipy, I translated it from a MATLAB implementation. I am struggling to implement it in Jax so that it supports…
Kapil
  • 81
  • 5
0
votes
1 answer

High Memory Consumption in JAX with Nested vmap

I'm working on a problem that involves computing the value of many interpolants on a three dimensional grid using jax. Following standard jax practice, I wrote everything for "single-batch" inputs and then vmap over all interpolants and evaluation…
crypty
  • 53
  • 4
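A possible mitigation, offered as a sketch under assumptions since the asker's code is not shown: keep `jax.vmap` for the inner axis and replace the outer one with `jax.lax.map`, which loops sequentially and so avoids materializing every intermediate for all interpolants at once. The `evaluate` function, a polynomial in the squared norm of the point, is a hypothetical stand-in for a real interpolant.

```python
import jax
import jax.numpy as jnp

def evaluate(coeffs, point):
    # Hypothetical "interpolant": polynomial in the squared norm of point.
    r2 = jnp.sum(point ** 2)
    return jnp.polyval(coeffs, r2)

# Vectorize over evaluation points only.
over_points = jax.vmap(evaluate, in_axes=(None, 0))

@jax.jit
def evaluate_all(all_coeffs, points):
    # Sequential loop over interpolants instead of a second vmap.
    return jax.lax.map(lambda c: over_points(c, points), all_coeffs)

all_coeffs = jnp.ones((4, 3))   # 4 interpolants, 3 coefficients each
points = jnp.zeros((10, 3))     # 10 evaluation points in 3D
values = evaluate_all(all_coeffs, points)  # shape (4, 10)
```

This trades some parallelism for a smaller peak memory footprint; whether it helps in practice depends on the sizes involved.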
0
votes
1 answer

How to write tensorboard events files without installing / importing TF or PyTorch?

Obviously events-file logging is included with TensorFlow and apparently there's an implementation included with PyTorch, but is there an officially supported standalone implementation of something like SummaryWriter for use outside of these two…
0yy
  • 15
  • 4