Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations. Further optimization comes from automatic vectorization and from running code on GPUs and TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
4
votes
1 answer

Why can't jax.grad give me the derivatives when my function uses np.power?

I want to train a simple linear model. The x and y below are my data: import numpy as np x = np.linspace(0,1,100) y = 2 * x + 3 + np.random.randn(100) f is a function that calculates the mean squared error over all the data: def f(params, x, y): return…
kankan256
  • 210
  • 1
  • 4
  • 18
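A minimal sketch of the usual fix, assuming the failure comes from calling NumPy's `np.power` on values that `jax.grad` is tracing: use `jax.numpy` operations inside the differentiated function. The body of `f` is truncated in the question, so the `mse` function and the noise-free data below are illustrative assumptions, not the asker's code.

```python
import jax
import jax.numpy as jnp

def mse(params, x, y):
    # Hypothetical stand-in for the truncated f(params, x, y).
    w, b = params
    pred = w * x + b
    # jnp.power can be traced by JAX; NumPy's np.power cannot handle tracers.
    return jnp.mean(jnp.power(pred - y, 2))

x = jnp.linspace(0.0, 1.0, 100)
y = 2.0 * x + 3.0  # noise-free targets, for illustration only

params = jnp.array([0.0, 0.0])
grads = jax.grad(mse)(params, x, y)  # gradient w.r.t. params, shape (2,)
```

`jax.grad` differentiates with respect to the first argument by default, so `grads` has the same shape as `params`.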
4
votes
0 answers

Saving Gradient in Backward Pass Google-JAX

I am using JAX to implement a simple neural network (NN) and I want to access and save the gradients from the backward pass for further analysis after the NN has run. I can access and look at the gradients temporarily with the Python debugger (as long…
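One possible approach, sketched under assumptions: in JAX the gradients are ordinary return values of `jax.grad` or `jax.value_and_grad`, so they can be appended to a Python list on each step. The model, the data, and the names `loss_fn` and `grad_history` are hypothetical, since the question's network is not shown.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical one-layer model; the question's network is not shown.
    pred = jnp.tanh(x @ params["w"] + params["b"])
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

grad_history = []
for step in range(3):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Copy the gradient pytree to host memory so it can be inspected later.
    grad_history.append(jax.device_get(grads))
    # Plain gradient-descent update (the learning rate 0.1 is arbitrary).
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

After the loop, `grad_history` holds one gradient pytree per step, with the same structure as `params`.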
4
votes
1 answer

Efficient way to compute Jacobian x Jacobian.T

Assume J is the Jacobian of some function f with respect to some parameters. Are there efficient ways (in PyTorch or perhaps JAX) to have a function that takes two inputs (x1 and x2) and computes J(x1)*J(x2).transpose() without instantiating the…
Milad
  • 4,901
  • 5
  • 32
  • 43
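A matrix-free sketch in JAX, assuming that what is needed is the product of J(x1) J(x2)ᵀ with a vector `v` rather than the full matrix: a vector-Jacobian product at `x2` followed by a Jacobian-vector product at `x1` avoids building either Jacobian. The model `f` and the helper name `jjt_matvec` are hypothetical.

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Hypothetical model; J is the Jacobian of f with respect to params.
    return jnp.tanh(params @ x)

def jjt_matvec(params, x1, x2, v):
    # u = J(x2)^T v, computed with a vector-Jacobian product.
    _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
    (u,) = vjp_fn(v)
    # J(x1) u, computed with a Jacobian-vector product.
    _, out = jax.jvp(lambda p: f(p, x1), (params,), (u,))
    return out

params = jnp.arange(6.0).reshape(2, 3) / 10.0
x1 = jnp.array([1.0, 0.5, -1.0])
x2 = jnp.array([0.2, -0.3, 0.7])
v = jnp.array([1.0, -2.0])
result = jjt_matvec(params, x1, x2, v)  # shape (2,)
```

If the full matrix is required, one option is to apply `jax.vmap` to `jjt_matvec` over the rows of an identity matrix, at the cost of one pass per output dimension.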
3
votes
2 answers

implementing if-then-elif-then-else in jax

I'm just starting to use JAX, and I wonder—what would be the right way to implement if-then-elif-then-else in JAX/Python? For example, given input arrays: n = [5, 4, 3, 2] and k = [3, 3, 3, 3], I need to implement the following pseudo-code: def…
Terry
  • 310
  • 3
  • 9
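A common pattern for element-wise if/elif/else in JAX is to nest `jnp.where`. The question's pseudo-code is truncated, so the three branches below (1 when n > k, 0 when n == k, -1 otherwise) are assumptions chosen only to illustrate the shape of the solution.

```python
import jax
import jax.numpy as jnp

@jax.jit
def classify(n, k):
    # if n > k: 1, elif n == k: 0, else: -1 (hypothetical branches)
    return jnp.where(n > k, 1, jnp.where(n == k, 0, -1))

n = jnp.array([5, 4, 3, 2])
k = jnp.array([3, 3, 3, 3])
result = classify(n, k)  # → [1, 1, 0, -1]
```

Both sides of each `jnp.where` are evaluated for every element, which is fine for cheap expressions and keeps the function compatible with `jit` and `vmap`.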
3
votes
1 answer

What are the tradeoffs between jax.lax.map and jax.vmap?

This GitHub issue hints that there are tradeoffs in performance / memory / compilation time when choosing between jax.lax.map and jax.vmap. What are the specific details of these tradeoffs with respect to both GPUs and CPUs?
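For orientation, a small sketch showing that the two transformations compute the same result in different ways: `jax.vmap` turns the function into a single batched computation, while `jax.lax.map` applies it to one element at a time in a sequential loop. The function `f` is a hypothetical example and says nothing about the timings discussed in the issue.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical per-example function.
    return jnp.sum(jnp.sin(x) ** 2)

xs = jnp.arange(12.0).reshape(4, 3)

out_vmap = jax.vmap(f)(xs)    # one batched computation over the leading axis
out_map = jax.lax.map(f, xs)  # sequential loop over the leading axis
```

Because `jax.lax.map` handles one element per iteration, its intermediate values are not batched, whereas `jax.vmap` materializes them for the whole batch at once.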
3
votes
1 answer

Updating entire row or column of a 2D array in JAX

I'm new to JAX and writing code that JIT compiles is proving to be quite hard for me. I am trying to achieve the following: Given an (n,n) array mat in JAX, I would like to add a (1,n) or an (n,1) array to an arbitrary row or column, respectively,…
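A minimal sketch using JAX's indexed-update syntax, `mat.at[...].add(...)`, which works under `jit` with a traced index. The helper names `add_to_row` and `add_to_col` are hypothetical, and the (1,n) or (n,1) input is flattened to length n before the update.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_to_row(mat, i, row):
    # row has shape (1, n); flatten it to match the selected row.
    return mat.at[i, :].add(row.reshape(-1))

@jax.jit
def add_to_col(mat, j, col):
    # col has shape (n, 1); flatten it to match the selected column.
    return mat.at[:, j].add(col.reshape(-1))

mat = jnp.zeros((3, 3))
row_updated = add_to_row(mat, 1, jnp.ones((1, 3)))
col_updated = add_to_col(mat, 2, 2.0 * jnp.ones((3, 1)))
```

JAX arrays are immutable, so each call returns a new array and leaves `mat` unchanged.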
3
votes
2 answers

JAX: JIT compatible sparse matrix slicing

I have a boolean sparse matrix that I represent with row indices and column indices of True values. import numpy as np import jax from jax import numpy as jnp N = 10000 M = 1000 X = np.random.randint(0, 100, size=(N, M)) == 0 # data setup rows,…
3
votes
2 answers

Purpose of stop gradient in `jax.nn.softmax`?

jax.nn.softmax is defined as: def softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1, where: Optional[Array] = None, initial: Optional[Array] = None) -> Array: x_max = jnp.max(x, axis,…
Jay Mody
  • 3,727
  • 1
  • 11
  • 27
3
votes
2 answers

Vectorize jax.lax.cond with vmap

Why can't I vectorize the condition function so that it applies to a list of booleans? Or is there something else going on here? DK = jnp.array([[True],[True],[False],[True]]) f1 = lambda x: 1 f2 = lambda y: 0 cond = lambda dk: jax.lax.cond(dk,f1,f2) vcond…
Kapil
  • 81
  • 5
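A sketch of one way to make this work, under two assumptions that differ from the question's snippet: the predicate array is flattened to shape (4,) so each mapped element is a scalar, and an explicit operand is passed to `jax.lax.cond` so that both branch functions receive an argument. The branch bodies are hypothetical.

```python
import jax
import jax.numpy as jnp

def branch(dk, x):
    # Both branches take the operand x and return the same shape and dtype.
    return jax.lax.cond(dk, lambda v: v * 2.0, lambda v: -v, x)

DK = jnp.array([True, True, False, True])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])

out = jax.vmap(branch)(DK, xs)      # → [2., 4., -3., 8.]
alt = jnp.where(DK, xs * 2.0, -xs)  # element-wise alternative, same values
```

For purely element-wise selection, `jnp.where` is often the simpler choice.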
3
votes
1 answer

How to vectorize a function over a list of unequal length arrays in JAX

This is a minimal example of the real larger problem I am facing. Consider the function below: import jax.numpy as jnp def test(x): return jnp.sum(x) I tried to vectorize it by: v_test = jax.vmap(test) My inputs to test look like: x1 =…
MOON
  • 2,516
  • 4
  • 31
  • 49
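`jax.vmap` needs inputs with a common shape, so one workaround is to pad the arrays to the same length and carry a mask. This is a sketch under assumptions: the question's inputs are truncated, so the three arrays below are made up, and the padding value only works because a masked sum ignores it.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs of unequal length.
arrays = [jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0]), jnp.array([6.0])]

max_len = max(a.shape[0] for a in arrays)
# Pad every array on the right with zeros up to max_len.
padded = jnp.stack([jnp.pad(a, (0, max_len - a.shape[0])) for a in arrays])
# mask[i, j] is True where padded[i, j] is real data rather than padding.
mask = jnp.stack([jnp.arange(max_len) < a.shape[0] for a in arrays])

def masked_sum(x, m):
    return jnp.sum(jnp.where(m, x, 0.0))

sums = jax.vmap(masked_sum)(padded, mask)  # → [6., 9., 6.]
```

For reductions other than a sum, the masked-out value has to be chosen to suit the operation (for example, negative infinity for a maximum).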
3
votes
2 answers

Why is JAX's `split()` so slow at first call?

jax.numpy.split can be used to segment an array into equal-length segments with a remainder in the last element. e.g. splitting an array of 5000 elements into segments of 10: array = jnp.ones(5000) segment_size = 10 split_indices =…
kdbanman
  • 10,161
  • 10
  • 46
  • 78
3
votes
1 answer

module 'jax' has no attribute 'tree_multimap' in AlphaFold2 CoLab

I am attempting to model a protein using an AlphaFold2 (AlphaFold v2.1.0.) CoLab (https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb#scrollTo=pc5-mbsX9PZC). I have done this successfully on 9/2/2022.…
Dr. Wilson
  • 31
  • 1
  • 2
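For context, `jax.tree_multimap` is no longer available in recent JAX releases, and `jax.tree_util.tree_map` accepts several trees. The sketch below shows the replacement call on made-up data. Whether patching the notebook this way or pinning an older JAX version is the better fix for the AlphaFold Colab is not something the excerpt settles.

```python
import jax
import jax.numpy as jnp

# Two hypothetical pytrees with the same structure.
a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
b = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(30.0)}

# Old style: jax.tree_multimap(lambda x, y: x + y, a, b)
summed = jax.tree_util.tree_map(lambda x, y: x + y, a, b)
```

The function is applied leaf by leaf, so `summed` has the same dictionary structure as `a` and `b`.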
3
votes
1 answer

JAX: Why does omitting @jit yield -inf when using it doesn't?

I'm fiddling around with JAX, and I came across two different results just by using the jit decorator: import jax import jax.numpy as jnp import jax.scipy.stats as jstats def jitless_log_likelihood(x, mu, sigma): return…
ABaron
  • 124
  • 7
3
votes
2 answers

What is the recommended way to do embeddings in jax?

So I mean something where you have a categorical feature $X$ (suppose you have turned it into ints already) and say you want to embed that in some dimension using the features $A$ where $A$ is arity x n_embed. What is the usual way to do this? Is…
mathtick
  • 6,487
  • 13
  • 56
  • 101
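One simple option, sketched with made-up sizes: treat $A$ as a lookup table and index it with the integer codes. The values of `arity`, `n_embed`, and the contents of `A` and `X` are assumptions for illustration.

```python
import jax.numpy as jnp

arity, n_embed = 5, 3
# Hypothetical embedding table of shape (arity, n_embed).
A = jnp.arange(arity * n_embed, dtype=jnp.float32).reshape(arity, n_embed)
# Integer-coded categorical feature.
X = jnp.array([0, 2, 2, 4])

emb = A[X]                         # shape (4, n_embed): one row of A per code
emb_take = jnp.take(A, X, axis=0)  # equivalent lookup
```

The lookup is differentiable with respect to `A`, so the table can be trained like any other parameter.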
3
votes
2 answers

jax linear solve issues

I am currently trying to implement my work within the JAX framework. However, I am now encountering an error when using the linear solve function from JAX. Here is an example taken directly from the NumPy linear algebra documentation page: import numpy as…
mabso
  • 33
  • 4
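For reference, a minimal working call to `jnp.linalg.solve` on a small system. The question's code and error message are truncated, so the matrix and right-hand side below are assumptions, and this sketch does not diagnose the reported error.

```python
import jax.numpy as jnp

# Hypothetical 2x2 system a @ x = b.
a = jnp.array([[1.0, 2.0], [3.0, 5.0]])
b = jnp.array([1.0, 2.0])

x = jnp.linalg.solve(a, b)  # solution is approximately [-1., 1.]
```

JAX uses 32-bit floats by default, so results may differ from NumPy's 64-bit output in the last few digits.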