Questions tagged [jax]

JAX allows writing auto-differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations. Further optimization comes from automatic vectorization and from running code on GPUs/TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
5
votes
1 answer

How to use jax vmap for nested loops?

I want to use vmap to vectorise this code for performance. def matrix(dataA, dataB): return jnp.array([[func(a, b) for b in dataB] for a in dataA]) matrix(data, data) I tried this: def f(x, y): return func(x, y) mapped =…
akkh
  • 140
  • 1
  • 8
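
A common approach (a sketch; func below is a stand-in for the question's unshown pairwise function) is to nest two vmaps: the inner one maps over dataB, the outer one over dataA, reproducing the nested comprehension:

    import jax
    import jax.numpy as jnp

    def func(a, b):
        # stand-in for the question's (unshown) pairwise function
        return jnp.sum(a * b)

    # inner vmap loops over dataB, outer vmap loops over dataA
    matrix = jax.vmap(jax.vmap(func, in_axes=(None, 0)), in_axes=(0, None))

    data = jnp.arange(12.0).reshape(4, 3)
    out = matrix(data, data)  # shape (4, 4), same as the nested comprehension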
5
votes
1 answer

Jax vectorization: vmap and/or numpy.vectorize?

What are the differences between jax.numpy.vectorize and jax.vmap? Here is a small snippet import jax import jax.numpy as jnp def f(x): return jnp.exp(-x)*jnp.sin(x) gf = jax.grad(f) x =…
Jean-Eric
  • 372
  • 2
  • 14
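
A minimal side-by-side sketch based on the question's snippet: both wrappers turn the scalar-only gradient into an elementwise function, but jax.vmap maps over an explicit positional axis, while jnp.vectorize emulates NumPy broadcasting semantics (and is itself built on vmap):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.exp(-x) * jnp.sin(x)

    gf = jax.grad(f)  # gradient of a scalar-to-scalar function

    x = jnp.linspace(0.0, 1.0, 5)

    g1 = jax.vmap(gf)(x)         # maps over axis 0 explicitly; composable and jit-friendly
    g2 = jnp.vectorize(gf)(x)    # NumPy-style broadcasting semantics
    print(jnp.allclose(g1, g2))  # True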
5
votes
3 answers

Create a 3D tensor of zeros with exactly one '1' randomly placed on every slice in numpy/jax

I need to create a 3D tensor like this (5,3,2) for example array([[[0, 0], [0, 1], [0, 0]], [[1, 0], [0, 0], [0, 0]], [[0, 0], [1, 0], [0, 0]], [[0, 0], [0, 0], …
Atul Vinayak
  • 466
  • 3
  • 15
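
One way to sketch this (assuming a fresh PRNG key; the (5, 3, 2) shape follows the question) is to draw one random flat index per slice and one-hot encode it:

    import jax
    import jax.numpy as jnp

    def one_per_slice(key, n=5, rows=3, cols=2):
        # choose one flat position per slice, then one-hot encode and reshape
        idx = jax.random.randint(key, (n,), 0, rows * cols)
        return jax.nn.one_hot(idx, rows * cols, dtype=jnp.int32).reshape(n, rows, cols)

    t = one_per_slice(jax.random.PRNGKey(0))  # shape (5, 3, 2), exactly one 1 per slice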
5
votes
2 answers

Why is this function slower in JAX vs numpy?

I have the following NumPy function, seen below, that I'm trying to optimize using JAX, but for whatever reason it's slower. Could someone point out what I can do to improve the performance here? I suspect it has to do with the list…
DumbCoder21
  • 113
  • 2
  • 7
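
Independent of the specific function, JAX benchmarks are often skewed by compilation on the first call and by asynchronous dispatch. A common measuring pattern (a sketch with a hypothetical workload, not the question's code):

    import timeit
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # hypothetical workload standing in for the question's function
        return jnp.sum(x ** 2 + x)

    x = jnp.ones((1000, 1000))
    f(x).block_until_ready()  # first call compiles; keep it out of the timing

    # block_until_ready forces the async computation to finish before the clock stops
    print(timeit.timeit(lambda: f(x).block_until_ready(), number=100))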
5
votes
1 answer

What is the difference between JAX, Trax, and TensorRT, in simple terms?

I have been using TensorRT and TensorFlow-TRT to accelerate the inference of my DL algorithms. Then I have heard of: JAX https://github.com/google/jax Trax https://github.com/google/trax Both seem to accelerate DL. But I am having a hard time to…
Aizzaac
  • 3,146
  • 8
  • 29
  • 61
4
votes
2 answers

Is there a way to update multiple indexes of Jax array at once?

Since arrays are immutable in JAX, updating N indexes creates N arrays with x = x.at[idx].set(y). With hundreds of updates per training cycle, it will ultimately create hundreds of arrays, if not millions. This seems a little wasteful,…
move37
  • 79
  • 4
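
.at[idx].set accepts index arrays, so many positions can be updated in a single call; inside jit, XLA can usually perform the update in place rather than copying. A minimal sketch:

    import jax.numpy as jnp

    x = jnp.zeros(10)
    idx = jnp.array([1, 4, 7])
    vals = jnp.array([10.0, 20.0, 30.0])

    # one call updates all three positions; under jit this typically becomes
    # an in-place scatter rather than three separate copies
    x = x.at[idx].set(vals)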
4
votes
2 answers

How to improve Julia's performance using just in time compilation (JIT)

I have been playing with JAX (the automatic differentiation library in Python) and Zygote (the automatic differentiation library in Julia) to implement the Gauss-Newton minimisation method. I came upon the @jit decorator in JAX that runs my Python code in…
MOON
  • 2,516
  • 4
  • 31
  • 49
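
For reference, the JAX side of the comparison: jax.jit compiles a traced function with XLA on its first call and caches the result. A sketch of a Gauss-Newton step (the solver details here are illustrative, not the question's code):

    import jax
    import jax.numpy as jnp

    def gauss_newton_step(J, r):
        # solve the normal equations (J^T J) dx = J^T r for one update step
        return jnp.linalg.solve(J.T @ J, J.T @ r)

    fast_step = jax.jit(gauss_newton_step)  # compiled on first call, cached after

    J = jnp.ones((6, 3)) + jnp.eye(6, 3)
    r = jnp.arange(6.0)
    dx = fast_step(J, r)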
4
votes
1 answer

JAX - jitting functions: parameters vs "global" variables

I have the following doubt about JAX. I'll use an example from the official optax docs to illustrate it: def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params: opt_state = optimizer.init(params) @jax.jit …
Liuka
  • 289
  • 2
  • 10
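
The key behavior: values captured from the enclosing scope are baked into the trace as constants, while parameters are traced fresh on each signature. A minimal sketch showing that a later change to a captured global is silently ignored by the cached compilation:

    import jax
    import jax.numpy as jnp

    scale = 2.0  # "global" captured when the function is first traced

    @jax.jit
    def f(x):
        return scale * x  # scale is baked into the compiled code as a constant

    print(f(jnp.arange(3.0)))  # [0. 2. 4.]

    scale = 10.0               # no effect: the cached trace still uses 2.0
    print(f(jnp.arange(3.0)))  # [0. 2. 4.]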
4
votes
1 answer

Is it safe to read the value of numpy.empty or jax.numpy.empty?

In Flax, we typically initialize a model by passing in a random vector and letting the library figure out the correct shape for the parameters via shape inference. For example, this is what the tutorial did def create_train_state(rng, learning_rate,…
nalzok
  • 14,965
  • 21
  • 72
  • 139
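
The two behave differently: np.empty returns genuinely uninitialized memory, so reading it gives arbitrary values, while jax.numpy.empty cannot (JAX arrays are immutable and XLA has no uninitialized buffers) and returns zeros instead. A quick sketch:

    import numpy as np
    import jax.numpy as jnp

    a = np.empty(4)   # uninitialized memory: values are arbitrary leftovers
    b = jnp.empty(4)  # JAX cannot create uninitialized buffers; this is all zeros

    print(b)  # [0. 0. 0. 0.]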
4
votes
1 answer

How to generate random numbers between 0 and 1 in jax?

How can I generate random numbers between 0 and 1 in jax? Basically I am looking to replicate the following function from numpy in jax. np.random.random(1000)
Bunny Rabbit
  • 8,213
  • 16
  • 66
  • 106
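
The JAX equivalent uses an explicit PRNG key with jax.random.uniform, which samples from [0, 1) by default:

    import jax

    key = jax.random.PRNGKey(42)           # explicit, splittable PRNG state
    x = jax.random.uniform(key, (1000,))   # uniform on [0, 1), like np.random.random(1000)

    # split the key before drawing again so samples stay independent
    key, subkey = jax.random.split(key)
    y = jax.random.uniform(subkey, (1000,))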
4
votes
1 answer

Flax much slower than pure Jax for neural networks?

For a project, I am trying to code up a very simple MLP example, but I noticed that the Flax implementation is about 20 times slower than the pure JAX implementation. What am I doing wrong here? import time import jax.numpy as np from jax import…
Luca Thiede
  • 3,229
  • 4
  • 21
  • 32
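
A frequent cause in such comparisons is calling model.apply without jit, so every forward pass pays per-op dispatch overhead. A minimal Flax sketch (a hypothetical MLP, not the question's code):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(64)(x))
            return nn.Dense(1)(x)

    model = MLP()
    x = jnp.ones((32, 8))
    params = model.init(jax.random.PRNGKey(0), x)

    apply_fn = jax.jit(model.apply)  # without jit, apply dispatches op by op
    y = apply_fn(params, x)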
4
votes
1 answer

Fastest way to multiply and sum 4D array with 2D array in python?

Here's my problem. I have two matrices A and B, with complex entries, of dimensions (n,n,m,m) and (n,n) respectively. Below is the operation I perform to get a matrix C - C = np.sum(B[:,:,None,None]*A, axis=(0,1)) Computing the above once takes…
Prasad Mani
  • 155
  • 5
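
The contraction sums B[i,j]·A[i,j,k,l] over i and j, which maps directly onto einsum and avoids materializing the broadcast (n, n, m, m) intermediate. A sketch with small illustrative shapes (real entries for brevity; complex works the same way):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def contract(A, B):
        # identical to np.sum(B[:, :, None, None] * A, axis=(0, 1)),
        # but without the broadcast temporary
        return jnp.einsum('ij,ijkl->kl', B, A)

    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    A = jax.random.normal(key_a, (20, 20, 4, 4))
    B = jax.random.normal(key_b, (20, 20))
    C = contract(A, B)  # shape (4, 4)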
4
votes
2 answers

Turn a tf.data.Dataset to a jax.numpy iterator

I am interested in training a neural network using JAX. I had a look at tf.data.Dataset, but it provides exclusively tf tensors. I looked for a way to change the dataset into JAX numpy arrays and I found a lot of implementations that use…
Valentin Goldité
  • 1,040
  • 4
  • 13
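
tf.data exposes this directly: Dataset.as_numpy_iterator() yields plain NumPy arrays, which JAX accepts without conversion. A minimal sketch:

    import tensorflow as tf
    import jax.numpy as jnp

    ds = tf.data.Dataset.from_tensor_slices((tf.range(8), tf.range(8))).batch(4)

    # as_numpy_iterator() yields NumPy arrays rather than tf tensors
    for x, y in ds.as_numpy_iterator():
        x = jnp.asarray(x)  # moved to the default device here (or lazily inside jit)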
4
votes
1 answer

Jax/Flax (very) slow RNN-forward-pass compared to pyTorch?

I recently implemented a two-layer GRU network in Jax and was disappointed by its performance (it was unusable). So, I tried a little speed comparison with PyTorch. Minimal working example This is my minimal working example and the output was…
Simon B
  • 199
  • 1
  • 9
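
Independent of the question's exact model, the usual advice for recurrent loops in JAX is to express the time dimension with lax.scan, so the step function is traced once and the whole sequence runs inside XLA. A sketch of a toy recurrence (the parameter names are illustrative, not the question's GRU):

    import jax
    from jax import lax
    import jax.numpy as jnp

    def rnn_forward(params, xs, h0):
        # lax.scan traces the step once, instead of unrolling a Python loop
        def step(h, x):
            h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'])
            return h, h
        _, hs = lax.scan(step, h0, xs)
        return hs

    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    params = {'Wx': jax.random.normal(k1, (8, 16)),
              'Wh': jax.random.normal(k2, (16, 16))}
    hs = jax.jit(rnn_forward)(params, jnp.ones((100, 8)), jnp.zeros(16))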
4
votes
2 answers

How to reduce JAX compile time when using for loop?

This is a basic example. @jax.jit def block(arg1, arg2): for x1 in range(cons1): for x2 in range(cons2): for x3 in range(cons3): --do something-- return result When cons are small, the compile-time is around a…
akkh
  • 140
  • 1
  • 8
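
Python for loops are unrolled during tracing, so the compiled program (and compile time) grows with the trip count; rewriting the loop with jax.lax.fori_loop (or lax.scan) traces the body once. A sketch (the reduction inside is illustrative, since the question's loop body is elided):

    import jax
    from jax import lax
    import jax.numpy as jnp

    @jax.jit
    def block(arg1, arg2):
        # the body is traced once; a Python range() would be unrolled into
        # one copy of the body per iteration, blowing up compile time
        def body(i, acc):
            return acc + arg1[i] * arg2[i]
        return lax.fori_loop(0, arg1.shape[0], body, 0.0)

    x = jnp.arange(1000.0)
    print(block(x, x))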