Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations, and gains further speed from automatic vectorization and from running code on GPUs/TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
0 votes • 1 answer

Insulate a segment of code from jax tracing

Apologies in advance for how vague this question is (unfortunately I don't know enough about how jax tracing works to phrase it more precisely), but: Is there a way to completely insulate a function or code block from jax tracing? For context, I…
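One approach that may fit (a sketch, assuming a JAX version that provides jax.pure_callback and that a callback covers the use case): the callback body runs as ordinary Python/NumPy and is opaque to tracing; only its output shape and dtype have to be declared.

import jax
import jax.numpy as jnp
import numpy as np

def untraced_piece(x):
    # Plain NumPy: JAX tracing never enters this function.
    return np.sin(np.asarray(x)) * 2.0

@jax.jit
def f(x):
    y = x + 1.0
    # Declare the result shape/dtype so JAX can trace around the callback.
    z = jax.pure_callback(untraced_piece,
                          jax.ShapeDtypeStruct(y.shape, y.dtype), y)
    return z - 1.0

print(f(jnp.arange(3.0)))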
0 votes • 1 answer

Convert npz jax weights into keras h5 weights

Is there any way to convert JAX npz pre-trained weights into keras/tf.keras h5 format weights? Couldn't find anything online. Thanks
craft • 495
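A rough sketch of one possible route (the file names are hypothetical, and it assumes the npz arrays line up one-to-one with the Keras model's weight order, which real checkpoints rarely guarantee without name-based remapping):

import numpy as np
import tensorflow as tf

params = np.load("jax_weights.npz")         # hypothetical file
arrays = [params[k] for k in params.files]  # order must match model.weights

# Rebuild the same architecture in Keras, then copy the arrays across.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])
model.set_weights(arrays)                   # shapes must match exactly
model.save("converted.h5")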
0 votes • 0 answers

Working with Google's JAX inside a gunicorn/flask server

I want to serve an application that processes data with Google's JAX framework, using Flask and gunicorn. If run inside Flask alone, everything works fine. As soon as I run the application within gunicorn, every JAX-related part results in the worker…
Flo Win • 154
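One workaround often suggested for fork-related crashes (an assumption about the cause, not a confirmed diagnosis) is to initialize JAX inside the worker rather than at module import time, so a forked gunicorn worker does not inherit JAX state from the master process:

from flask import Flask

app = Flask(__name__)

@app.route("/predict")
def predict():
    import jax.numpy as jnp  # lazy import: JAX initializes inside the worker
    return str(jnp.sum(jnp.arange(10)))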
0 votes • 0 answers

Training a neural network on the gradient of the input with PyTorch

I am currently trying to train a neural network with PyTorch, where I try to match the network's derivative with respect to the input to a target. I want to do this because it ensures a conservative vector field. (Done in training of neural networks for force matching in…
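For readers on the JAX side of this tag, a minimal sketch of the same idea (the toy architecture and data are my own): the network outputs a scalar potential, and the loss is placed on its input gradient, which makes the predicted field conservative by construction.

import jax
import jax.numpy as jnp

def energy(params, x):
    w, b = params
    return jnp.sum(jnp.tanh(x @ w + b))              # toy scalar potential E(x)

def predicted_force(params, x):
    return -jax.grad(energy, argnums=1)(params, x)   # F = -dE/dx

def loss(params, x, target_force):
    return jnp.mean((predicted_force(params, x) - target_force) ** 2)

key = jax.random.PRNGKey(0)
params = (jax.random.normal(key, (3, 8)), jnp.zeros(8))
x = jax.random.normal(key, (3,))
value, grads = jax.value_and_grad(loss)(params, x, jnp.zeros(3))
print(value)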
0 votes • 0 answers

Very slow jit compile for XLA when using jax

I am using JAX to do some machine learning jobs. JAX uses XLA to just-in-time compile code for acceleration, but the compilation itself is too slow on CPU. My situation is that the CPU uses only a single core to do the compile, which is not…
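A small sketch for at least separating the compile cost from the run cost (assumes a JAX version recent enough to have the ahead-of-time .lower()/.compile() API):

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x @ x.T)

x = jnp.ones((512, 512))

t0 = time.perf_counter()
compiled = f.lower(x).compile()   # XLA compilation happens here
print("compile time:", time.perf_counter() - t0)

t0 = time.perf_counter()
compiled(x).block_until_ready()   # pure execution, no compile
print("run time:", time.perf_counter() - t0)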
0 votes • 1 answer

jax vmap: enforce correct shape

I'm using vmap to vectorize parts of my code. Here is a minimal example, before the vectorization:

dim = 2

def sum(x):
    a = np.ones((dim,))
    return np.dot(x, a)

num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) #…
lhk • 27,458
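A sketch of the vectorized version (renaming sum to avoid shadowing the builtin): with in_axes=0, each call sees a single (dim,)-shaped row, so inputs of the wrong rank fail loudly instead of silently broadcasting.

import jax
import jax.numpy as jnp

dim = 2

def row_sum(x):
    a = jnp.ones((dim,))
    return jnp.dot(x, a)

num_samples = 100
samples = jnp.ones((num_samples, dim))

batched = jax.vmap(row_sum, in_axes=0)
print(batched(samples).shape)   # (100,)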
-1 votes • 0 answers

How do you install jax on Mac M2?

I am trying to install jax on a Mac M2. pip install jax works, but results in an error that states: RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able to…
John Kitchin • 2,236
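One frequent cause of that message (an assumption, since the full setup isn't shown) is an x86_64 Python running under Rosetta, which pulls in an AVX build of jaxlib. A quick check from Python:

import platform

# On Apple silicon this should print 'arm64'; 'x86_64' means the interpreter
# runs under Rosetta, and a native arm64 Python is needed before
# `pip install "jax[cpu]"` can pick up an Apple-silicon jaxlib wheel.
print(platform.machine())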
-1 votes • 1 answer

fitting a model perfectly using jax in machine learning

Link to text file. Hi, I am relatively new to machine learning. I have managed to get a model like in the image attached, and I would like to know what more I can do for the model to fit perfectly [model I made]. I don't know much about choosing loss…
jeevan • 13
-1 votes • 1 answer

"jaxDecomp installation error" Run setup.py,command execution error #3

error message:

CMake Error at CMakeLists.txt:5 (find_package):
  By not providing "FindNVHPC.cmake" in CMAKE_MODULE_PATH this project has
  asked CMake to find a package configuration file provided by "NVHPC",
  but CMake did not find one. Could not find…
-1 votes • 1 answer

Why does JAX throw an unfiltered stack trace?

I need to jit the train step, but when I do I get this error:

import jax_resnet
import jax
import jax.numpy as jnp
from flax import linen as nn
import tensorflow_datasets as tfds
from flax.training import train_state
import optax
import numpy as…
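As a side note on the filtering itself, JAX exposes a config switch for it (this controls how much of the traceback is shown, not the underlying error):

import jax

# "off" shows the full, unfiltered traceback; the default hides JAX internals.
jax.config.update("jax_traceback_filtering", "off")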
-1 votes • 1 answer

If there are two functions, one with jit and the other without, and I iterate each of them 100 times, the unjitted function takes less time than the jitted one

import jax
import numpy as np
import jax.numpy as jnp

a = []
a_jax = []
for i in range(10000):
    a.append(np.random.randint(1, 5, (5,)))
    a_jax.append(jnp.array(a[i]))
# a_jax = jnp.array(a_jax)

@jax.jit
def calc_add_with_jit(a, b):
    return a +…
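Two details usually explain this result: JAX dispatches work asynchronously, so a fair timing has to call block_until_ready(), and the first jitted call pays the compile cost, so it should be excluded with a warm-up call. For arrays this small, per-call dispatch overhead can still leave plain NumPy ahead. A sketch of the corrected timing:

import time
import jax
import jax.numpy as jnp

@jax.jit
def calc_add_with_jit(a, b):
    return a + b

a = jnp.ones((5,))
b = jnp.ones((5,))

calc_add_with_jit(a, b).block_until_ready()   # warm-up: compilation happens here

t0 = time.perf_counter()
for _ in range(100):
    calc_add_with_jit(a, b).block_until_ready()   # wait for async dispatch
print("jit:", time.perf_counter() - t0)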
-1 votes • 1 answer

How do you get the gradients of a loss function containing argmax in Jax?

I am facing this issue where I get zero gradients after using argmax in a loss function. I have created a minimal example:

import haiku as hk
import jax.numpy as jnp
import jax.random
import optax
import chex

hidden_dim = 64
input_shape =…
Light • 375
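argmax is piecewise constant, so its gradient is zero almost everywhere. One common workaround (one of several; the shapes here are made up) is a differentiable soft-argmax built from softmax:

import jax
import jax.numpy as jnp

def soft_argmax(logits, temperature=0.1):
    # Approaches the hard argmax as temperature -> 0, but stays differentiable.
    weights = jax.nn.softmax(logits / temperature)
    positions = jnp.arange(logits.shape[-1], dtype=logits.dtype)
    return jnp.sum(weights * positions, axis=-1)

def loss(logits, target):
    return (soft_argmax(logits) - target) ** 2

logits = jnp.array([0.1, 2.0, 0.3])
print(jax.grad(loss)(logits, 1.0))   # nonzero gradients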
-1 votes • 1 answer

What is the "correct" or best way in jax to implement a Dense layer where each layer might or might not have a bias?

For example, in jax.experimental.stax there is a Dense layer implemented like this:

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng,…
mathtick • 6,487
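A sketch of one way to extend that stax pattern with an optional bias (the use_bias flag is my addition, not part of jax.experimental.stax):

import jax
import jax.numpy as jnp
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal(), use_bias=True):
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = jax.random.split(rng)
        W = W_init(k1, (input_shape[-1], out_dim))
        params = (W, b_init(k2, (out_dim,))) if use_bias else (W,)
        return output_shape, params

    def apply_fun(params, inputs, **kwargs):
        if use_bias:
            W, b = params
            return jnp.dot(inputs, W) + b
        (W,) = params
        return jnp.dot(inputs, W)

    return init_fun, apply_fun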
-1 votes • 1 answer

All pairwise cross products of the rows of two matrices

I would like to efficiently calculate all pairwise cross products of the rows of two matrices, A and B, which are n x 3 and m x 3 in size, and would ideally like to achieve this in einsum notation. I.e. the output matrix C would be (n x m x…
oracle3001 • 1,090
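A sketch of two equivalent ways to get the (n, m, 3) result: jnp.cross broadcasts directly, while the einsum form needs the Levi-Civita tensor, since a plain einsum cannot express a cross product on its own.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
A = random.normal(key, (4, 3))   # n = 4
B = random.normal(key, (5, 3))   # m = 5

# Broadcasting: (n, 1, 3) x (1, m, 3) -> (n, m, 3)
C = jnp.cross(A[:, None, :], B[None, :, :])
print(C.shape)   # (4, 5, 3)

# einsum with the Levi-Civita tensor eps[i, j, k]: (a x b)_i = eps_ijk a_j b_k
eps = jnp.array([[[0, 0, 0], [0, 0, 1], [0, -1, 0]],
                 [[0, 0, -1], [0, 0, 0], [1, 0, 0]],
                 [[0, 1, 0], [-1, 0, 0], [0, 0, 0]]], dtype=A.dtype)
C2 = jnp.einsum('ijk,nj,mk->nmi', eps, A, B)
print(jnp.allclose(C, C2))   # True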
-1 votes • 1 answer

No library found under: /usr/local/cuda-9.0/targets/aarch64-linux/lib/libcublasLt.so.9.0

I'm trying to install JAX on the NVIDIA Jetson TX2 and I'm facing considerable issues. I have CUDA 9.0 and it gives me the following error: No library found under: /usr/local/cuda-9.0/targets/aarch64-linux/lib/libcublasLt.so.9.0 So I go looking and…
DumbCoder21 • 113