Questions tagged [jax]

JAX lets you write functions that can be differentiated automatically. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs/TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions

3 votes, 1 answer

vmap in Jax to loop over arguments

Let's suppose I have some function which returns a sum of inputs. @jit def some_func(a,r1,r2): return a + r1 + r2 Now I would like to loop over different values of r1 and r2, save the result and add it to a counter. This is what I mean: a = 0…
Zohim
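A minimal sketch of the vmap pattern this question asks about, assuming the goal is either to pair `r1[i]` with `r2[i]` or to evaluate every (r1, r2) combination. The value grids `r1_values` and `r2_values` are hypothetical, and a sum over the mapped results stands in for the Python counter.

```python
import jax
import jax.numpy as jnp

@jax.jit
def some_func(a, r1, r2):
    return a + r1 + r2

# Hypothetical values to sweep over.
r1_values = jnp.array([1.0, 2.0, 3.0])
r2_values = jnp.array([10.0, 20.0, 30.0])

# Paired sweep: in_axes=(None, 0, 0) keeps `a` fixed and maps over r1 and r2 together.
paired = jax.vmap(some_func, in_axes=(None, 0, 0))
paired_results = paired(0.0, r1_values, r2_values)  # shape (3,)

# Full grid: the inner vmap maps over r2, the outer one over r1,
# so grid[i, j] == a + r1_values[i] + r2_values[j].
grid_fn = jax.vmap(jax.vmap(some_func, in_axes=(None, None, 0)), in_axes=(None, 0, None))
grid = grid_fn(0.0, r1_values, r2_values)  # shape (3, 3)

# The "counter" from the question becomes a reduction over the mapped results.
counter = grid.sum()
```

Replacing the loop with vmap lets JAX compile one batched computation instead of dispatching one call per (r1, r2) pair.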

3 votes, 1 answer

Handle varying shapes in jax numpy arrays (jit compatible)

Important note: I need everything to be jit compatible here, otherwise my problem is trivial :) I have a jax numpy array such as: a = jnp.array([1,5,3,4,5,6,7,2,9]) First I filter it based on a value, let's assume that I only keep values that…
Valentin Macé
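Boolean indexing such as `a[a > 4]` has a data-dependent output shape, so jit cannot trace it. The following is a minimal sketch of two fixed-shape alternatives; the `threshold` condition is an assumption, since the excerpt's filter is cut off, and `filter_fixed_shape` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_fixed_shape(a, threshold):
    keep = a > threshold
    # Option 1: keep the original shape and zero out rejected entries.
    masked = jnp.where(keep, a, 0)
    # Option 2: indices of kept entries, padded with -1 up to a static size.
    idx = jnp.nonzero(keep, size=a.shape[0], fill_value=-1)[0]
    # Number of valid entries, so later code knows how much of `idx` is real.
    count = jnp.sum(keep)
    return masked, idx, count

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
masked, idx, count = filter_fixed_shape(a, 4)
```

Downstream code then carries `count` (or the boolean mask) along instead of relying on the array's length.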

3 votes, 2 answers

How could I speed up this looping code with JAX? Finding nearest neighbors for collision

I am trying to use JAX on another SO question to evaluate JAX's applicability and performance on that code (there is useful information there about what the code does). For this purpose, I have modified the code using jax.numpy (jnp) equivalent methods…
Ali_Sh
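One common way to remove a Python loop here is to compute all pairwise distances at once with broadcasting. The sketch below is a brute-force O(N^2) approach swapped in for illustration, not the referenced code; `points`, `radius` and the equal-radius collision rule are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nearest_neighbors(points, radius):
    # points: (N, D). Broadcasting (N, 1, D) - (1, N, D) gives all pairwise differences.
    diff = points[:, None, :] - points[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))  # (N, N)
    # A point must not count as its own neighbour.
    dist = jnp.where(jnp.eye(points.shape[0], dtype=bool), jnp.inf, dist)
    nearest = jnp.argmin(dist, axis=1)
    # Hypothetical collision rule: two spheres of equal radius overlap.
    colliding = dist < 2.0 * radius
    return nearest, jnp.min(dist, axis=1), colliding

points = jnp.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
nearest, min_dist, colliding = nearest_neighbors(points, 0.6)
```

The (N, N) distance matrix trades memory for speed; for large N a spatial grid or cell list would usually be needed instead.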

3 votes, 1 answer

Computing Hessian matrices efficiently in JAX

In JAX's Quickstart tutorial I found that the Hessian matrix can be computed efficiently for a differentiable function fun using the following lines of code: from jax import jacfwd, jacrev def hessian(fun): return…
Gilfoyle
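A minimal sketch of the forward-over-reverse composition, not necessarily the tutorial's exact code: `jacrev` computes the gradient and `jacfwd` differentiates it again. The JAX autodiff documentation describes forward-over-reverse as typically the most efficient ordering, and `jax.hessian` is built from the same composition. The test function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev

def hessian(fun):
    # Forward-over-reverse: reverse mode for the gradient, forward mode on top of it.
    return jax.jit(jacfwd(jacrev(fun)))

def f(x):
    # Hypothetical test function; its Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(f)(x)
H_builtin = jax.hessian(f)(x)
```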

3 votes, 1 answer

Automatic Differentiation with respect to rank-based computations

I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve. I have two input arrays - a vector A of size N and a matrix B of shape (N, M), as well as a parameter vector…

3 votes, 1 answer

JAX: avoid just-in-time recompilation for a function evaluated with a varying number of elements along one axis

Is it possible to avoid recompiling a JIT function when the structure of its input remains essentially unchanged, aside from one axis having a varying number of elements? import jax @jax.jit def f(x): print('recompiling') return (x + 10) *…
mutableVoid
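jit compiles once per distinct input shape, so a common workaround is to pad the varying axis up to a small set of bucket sizes and pass a mask. A minimal sketch, assuming 1-D inputs and power-of-two buckets; `pad_to_bucket` is a hypothetical helper, and the body of `f` is a stand-in because the question's body is truncated.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0

@jax.jit
def f(x, mask):
    global trace_count
    trace_count += 1  # Python side effect: runs only when jit traces a new shape
    # Stand-in body; padded entries are excluded through the mask.
    return jnp.sum(jnp.where(mask, x + 10, 0.0))

def pad_to_bucket(x):
    # Hypothetical helper: round the length up to the next power of two.
    n = x.shape[0]
    size = 1 << max(n - 1, 0).bit_length()
    padded = np.zeros((size,) + x.shape[1:], dtype=x.dtype)
    padded[:n] = x
    mask = np.zeros(size, dtype=bool)
    mask[:n] = True
    return padded, mask

results = []
for n in [3, 4, 5, 6, 7]:
    padded, mask = pad_to_bucket(np.arange(n, dtype=np.float32))
    results.append(float(f(padded, mask)))
# Lengths 3-4 share the size-4 bucket and 5-7 share size 8: two traces, not five.
```

Power-of-two buckets bound the number of compilations by roughly log2 of the largest length, at the cost of up to 2x wasted work on padding.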

3 votes, 2 answers

Multiple `vmap` in JAX?

This may be a very simple thing, but I was wondering how to perform mapping in the following example. Suppose we have a function whose derivative we want to evaluate with respect to xt, yt and zt, but which also takes additional parameters xs, ys and…
antelk
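A minimal sketch of combining `grad` with `vmap` through `in_axes`, assuming the goal is the derivative with respect to xt, yt and zt for many target points while xs, ys and zs stay fixed. `potential` is a hypothetical stand-in for the question's function; a second vmap then maps over source points as well.

```python
import jax
import jax.numpy as jnp

def potential(xt, yt, zt, xs, ys, zs):
    # Hypothetical scalar function of a target point and a source point.
    return 1.0 / jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Derivatives with respect to the first three arguments only.
grad_t = jax.grad(potential, argnums=(0, 1, 2))

# Map over targets (axis 0) while the source coordinates are broadcast (None).
over_targets = jax.vmap(grad_t, in_axes=(0, 0, 0, None, None, None))

# Map that again over sources: each result has shape (n_sources, n_targets).
over_both = jax.vmap(over_targets, in_axes=(None, None, None, 0, 0, 0))

xt = jnp.array([1.0, 2.0])
yt = jnp.zeros(2)
zt = jnp.zeros(2)
dx, dy, dz = over_targets(xt, yt, zt, 0.0, 0.0, 0.0)

xs = jnp.array([0.0, -1.0, 5.0])
gx, gy, gz = over_both(xt, yt, zt, xs, jnp.zeros(3), jnp.zeros(3))
```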

3 votes, 1 answer

JIT a least squares loss function in Jax

I have a simple loss function that looks like this def loss(r, x, y): resid = f(r, x) - y return jnp.mean(jnp.square(resid)) I would like to optimize over the parameter r and use some static parameters x and y to…
Carpetfizz
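A minimal sketch, assuming a hypothetical linear model `f(r, x) = r[0] * x + r[1]` and plain gradient descent. Because `x` and `y` are arrays, which are not hashable, they cannot be marked with `static_argnums`; passing them as ordinary traced arguments works, and jit only retraces when their shapes or dtypes change.

```python
import jax
import jax.numpy as jnp

def f(r, x):
    # Hypothetical model: a straight line with parameters r = (slope, intercept).
    return r[0] * x + r[1]

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

# Loss and gradient with respect to r (argnums=0 by default), compiled once.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

x = jnp.linspace(0.0, 1.0, 5)
y = 2.0 * x + 1.0
r = jnp.zeros(2)
for _ in range(500):
    value, g = loss_and_grad(r, x, y)
    r = r - 0.5 * g  # plain gradient-descent step
```

Closing over `x` and `y` (for example with `functools.partial(loss, x=x, y=y)`) also works, but then the data is baked into the compiled function as constants.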

3 votes, 2 answers

WebSocket messages are only sent at the end, not incrementally, when using async/await and yield in nested for loops

I have a computationally heavy process that takes several minutes to complete on the server, so I want to send the results of every iteration to the client via websockets. The overall application works, but my problem is that all the messages are…
1cgonza

3 votes, 1 answer

Debug array in jax vmap function

Dear JAX experts, I need your kind help. Here is a working example (I have followed the advice to simplify my code, although I am not expert enough in JAX or Python to guess what lies at the heart of the mechanism involved in vmap) def…
Jean-Eric
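Inside vmap or jit, a plain `print` shows abstract tracers rather than numbers. A minimal sketch using `jax.debug.print`, which is available in recent JAX releases and prints concrete values when the compiled function actually runs; `step` is a hypothetical function, not the one from the question.

```python
import jax
import jax.numpy as jnp

def step(x):
    y = jnp.sin(x) * 2.0
    print("traced:", y)                      # runs once at trace time and shows a tracer
    jax.debug.print("runtime y = {y}", y=y)  # shows concrete values when executed
    return y

batched = jax.jit(jax.vmap(step))
out = batched(jnp.array([0.0, 1.0, 2.0]))
```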

3 votes, 2 answers

Unable to import the Python package jax on a Google TPU

I am working in a Linux console, and typing python takes me into the Python console. When I run the following command on the TPU machine, import jax, it prints the following message and exits the Python prompt. paramjeetsingh80@t1v-n-1c883486-w-0:~$…

3 votes, 1 answer

JAX batching with different lengths

I have a function compute(x) where x is a jnp.ndarray. Now, I want to use vmap to transform it into a function that takes a batch of arrays x[i], and then jit to speed it up. compute(x) is something like: def compute(x): # ... some code y =…
Federico Taschin
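vmap requires every mapped array to have the same length along the batch axis, so arrays of different lengths are usually padded to a common length and paired with a validity mask. A minimal sketch, assuming 1-D inputs; `pad_batch` and the body of `compute` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def compute(x, mask):
    # Hypothetical per-example computation: mean of x**2 over the valid entries only.
    total = jnp.sum(jnp.where(mask, x ** 2, 0.0))
    return total / jnp.sum(mask)

batched_compute = jax.jit(jax.vmap(compute))

def pad_batch(arrays):
    # Pad every 1-D array to the longest length and record which entries are real.
    max_len = max(len(a) for a in arrays)
    x = np.zeros((len(arrays), max_len), dtype=np.float32)
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for i, a in enumerate(arrays):
        x[i, :len(a)] = a
        mask[i, :len(a)] = True
    return x, mask

x, mask = pad_batch([np.array([1.0, 2.0]), np.array([3.0]), np.array([1.0, 1.0, 1.0])])
out = batched_compute(x, mask)
```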

3 votes, 1 answer

Is it possible to jit a function that uses jax.numpy.unique?

The following code does not work: def get_unique(arr): return jnp.unique(arr) get_unique = jit(get_unique) get_unique(jnp.ones((10,))) The error message complains about the use of jnp.unique: FilteredStackTrace:…
lhk
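The output length of `jnp.unique` normally depends on the data, which jit cannot trace. `jnp.unique` also accepts a `size` argument (with an optional `fill_value`) that fixes the output length and makes it usable under jit. A minimal sketch, with `size` marked static and `-1` chosen arbitrarily as the padding value:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="size")
def get_unique(arr, size):
    # `size` must be a concrete Python int; the sorted unique values are padded
    # with fill_value when there are fewer than `size` of them.
    return jnp.unique(arr, size=size, fill_value=-1)

u_ints = get_unique(jnp.array([3, 1, 3, 2]), size=4)
u_ones = get_unique(jnp.ones((10,)), size=3)
```

If there are more unique values than `size`, only the first `size` sorted values are kept, so `size` should be an upper bound you can guarantee.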

3 votes, 1 answer

what exactly is `xla_client` in the jax library?

If you read the JAX source code, you'll come across something called xla_client, often imported like this: from . import xla_client. This implies that xla_client is a Python module, but I can't find any file with that name or reference to a variable of that…

3 votes, 1 answer

Jax cannot find the static argnums

This is related to this question. I managed to get most of the code working, except for one strange thing. Here is the modified code. import jax.numpy as jnp from jax import grad, jit, value_and_grad from jax import vmap, pmap from jax import…
RanWang
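The question's code is truncated, so the cause of the error is not visible here. As a minimal sketch of the mechanism only: `static_argnums` holds positional indices into the signature of the function being jitted (`static_argnames` takes names instead), and a static argument must be hashable, such as a Python int. `repeat_sum` and `repeat_sum_named` are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Argument 1 (`n`) is static: it is a plain Python int that determines an array shape.
@partial(jax.jit, static_argnums=(1,))
def repeat_sum(x, n):
    return jnp.sum(jnp.tile(x, n))

x = jnp.array([1.0, 2.0])
a = repeat_sum(x, 3)  # compiles for n=3
b = repeat_sum(x, 3)  # reuses the cached compilation
c = repeat_sum(x, 4)  # a new static value triggers a recompile

# Equivalent, selecting the static argument by name.
@partial(jax.jit, static_argnames="n")
def repeat_sum_named(x, n):
    return jnp.sum(jnp.tile(x, n))

d = repeat_sum_named(x, n=3)
```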