Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further optimization comes from automatic vectorization and from running code on GPUs/TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
0
votes
1 answer

(Conv1D) TensorFlow and JAX Produce Different Outputs for the Same Input

I am trying to use conv1d functions to perform a transposed convolution in JAX and TensorFlow, respectively. I read the documentation of both JAX and TensorFlow for the conv1d_transposed operation, but they produce different outputs for the…
0
votes
1 answer

How to check if a value is in an array while using jax

I have a negative sampling function that I want to use with JAX's @jit, but everything I try makes it stop working. The parameters are: key: a key for jax.random; ratings: a list of 3-tuples (user_id, item_id, 1); user_positives: a list of lists where…
Júlio Guedes
  • 55
  • 1
  • 12
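
A minimal sketch of a jit-compatible membership test, not taken from the question or its answer. It assumes the positives are stored in a fixed-shape array (a ragged list of lists would first need padding); the names `contains`, `contains_each`, and `positives` are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def contains(arr, value):
    # Compare every element with the value, then reduce with any().
    # There is no Python-level `in` or `if`, so this traces cleanly under jit.
    return jnp.any(arr == value)


@jax.jit
def contains_each(candidates, arr):
    # Broadcast (C, 1) against (1, N) to test several candidates at once.
    return (candidates[:, None] == arr[None, :]).any(axis=1)


positives = jnp.array([3, 7, 11])  # illustrative item ids
found = contains(positives, 7)
missing = contains(positives, 4)
mask = contains_each(jnp.array([7, 4, 11]), positives)
```

The result is a traced boolean array rather than a Python bool, so any follow-up selection inside the jitted function would typically go through `jnp.where` instead of an `if` statement.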
0
votes
1 answer

Memory usage in transforming fine tuning of GPTJ-6b to HuggingFace format

Following this tutorial using TPUs to fine tune GPTJ has worked well. https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md Why would the step to transform to huggingface format using to_hf_weights.py have an issue with…
Jonathan Hendler
  • 1,239
  • 1
  • 17
  • 23
0
votes
1 answer

Haiku & Jax weights initialisation

In Pytorch the following code can be used to initialise a layer: def init_layer(in_features, out_features): x = nn.Linear(in_features, out_features) limit = 1.0 / math.sqrt(in_features) x.weight = nn.Parameter( …
masha
  • 33
  • 3
0
votes
0 answers

How can I use the JAX library in my code with a numpy take related error: "NotImplementedError: The 'raise' mode to jnp.take is not supported."

To speed up my code, I rewrote it in pure NumPy so that I could compare its runtime with a JAX-accelerated version in Python. I don't know whether my code is suitable for acceleration with JAX, but my limited previous studies…
Ali_Sh
  • 2,667
  • 3
  • 43
  • 66
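
A short sketch of the out-of-bounds modes that `jnp.take` does accept, offered as one possible workaround rather than the fix for the asker's code, which is not shown here. The arrays are made up for illustration.

```python
import jax.numpy as jnp

values = jnp.array([10.0, 20.0, 30.0])
indices = jnp.array([0, 2, 4])  # 4 is out of bounds for length 3

# "clip" clamps out-of-range indices to the last valid position.
clipped = jnp.take(values, indices, mode="clip")

# "wrap" takes indices modulo the axis length, so 4 maps to 1.
wrapped = jnp.take(values, indices, mode="wrap")
```

Neither mode raises on a bad index, so if the NumPy version relied on the exception, the indices would need to be validated outside the JAX computation.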
0
votes
1 answer

IndexError: Array boolean indices must be concrete

I'm trying to find the non-zero elements of a list using the LAX-backend implementation of nonzero(). from jax import numpy as jnp Gamma = [[1, 1], [1, 0]] print(jnp.nonzero(Gamma[0])) I'm receiving the error IndexError: Array boolean indices…
Blade
  • 984
  • 3
  • 12
  • 34
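
A hedged sketch of calling `jnp.nonzero` where the output shape must be known ahead of time, for example inside a jitted function. Whether this matches the cause of the asker's error is an assumption; the sketch also assumes the Python list is first converted to a JAX array. The function name `nonzero_fixed` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nonzero_fixed(x):
    # Under jit the number of non-zero entries is unknown at trace time,
    # so `size` fixes the output length and `fill_value` pads unused slots.
    (idx,) = jnp.nonzero(x, size=x.shape[0], fill_value=-1)
    return idx


row = jnp.array([1, 0, 1, 1])
idx = nonzero_fixed(row)
```

Entries equal to `-1` mark padding here and would need to be masked out by the caller.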
0
votes
0 answers

JAX pmap is slower than jit(vmap), how to speed it up?

I have two fairly complex and independent computations that I want to run on two GPUs with pmap. Surprisingly, the pmap-ed version is much slower. I know that doubling the performance is almost impossible, but I expected better performance. Below is…
0
votes
1 answer

Conditioning on elements of matrix in JIT-ted function

I have a function that looks like this @jax.jit def f(R): tr = jnp.trace(R) r00 = R[0, 0] r01 = R[0, 1] r02 = R[0, 2] r10 = R[1, 0] r11 = R[1, 1] …
Carpetfizz
  • 8,707
  • 22
  • 85
  • 146
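
A minimal sketch of branching on a traced value inside a jitted function with `lax.cond`. The question's actual branches are cut off above, so both branch bodies below are placeholders, and `branch_on_trace` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def branch_on_trace(R):
    tr = jnp.trace(R)
    # A Python `if tr > 0:` would fail here because `tr` is a tracer.
    # lax.cond picks a branch at run time; both branches must return
    # values of the same shape and dtype.
    return lax.cond(
        tr > 0,
        lambda m: m[0, 0] + 1.0,  # placeholder "positive trace" branch
        lambda m: m[0, 0] - 1.0,  # placeholder "non-positive trace" branch
        R,
    )


pos = branch_on_trace(jnp.eye(3))
neg = branch_on_trace(-jnp.eye(3))
```

For cheap elementwise choices, `jnp.where(tr > 0, a, b)` is a common alternative; it evaluates both sides and selects afterwards.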
0
votes
1 answer

Vectorizing a Physics Simulation?

I'm trying to simulate some 2-dimensional particles. Each particle is a circle that has an orientation. The orientation is specified by a 2-dimensional unit vector. In one part of my simulation I'd like to calculate a function of the angle between…
iamsad
  • 9
  • 2
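
A sketch of one way to compute angles for all particle pairs at once, assuming the quantity of interest is the angle between pairs of orientation vectors (the excerpt is cut off before saying so). The function name and the sample data are hypothetical.

```python
import jax.numpy as jnp


def pairwise_angles(orientations):
    # orientations: (N, 2) array of unit vectors, one row per particle.
    # For unit vectors the dot product equals the cosine of the angle,
    # so a single matrix product gives all N x N cosines.
    cosines = orientations @ orientations.T
    # Clip to guard against rounding pushing values just outside [-1, 1].
    return jnp.arccos(jnp.clip(cosines, -1.0, 1.0))


orientations = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
angles = pairwise_angles(orientations)
```

The result is an unsigned angle in [0, π]; a signed angle would need `jnp.arctan2` on the 2-D cross and dot products instead.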
0
votes
1 answer

Is there a way to speed up indexing a vector with JAX?

I am indexing vectors and using JAX, but I have noticed a considerable slow-down compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and ordinary numpy: import jax.numpy as jnp import numpy as onp…
0
votes
1 answer

NCCL operation ncclGroupEnd() failed: unhandled system error

I am able to run this file vit_jax.ipynb on colab and perform training and run my experiments but when I try to replicate it on my cluster, I am getting an error during training given below. However, the forward pass to calculate accuracy works…
talos1904
  • 952
  • 3
  • 9
  • 24
0
votes
0 answers

Getting error '/usr/bin/bash: line 1: realpath: command not found' when installing JAX with CUDA

I am trying to build JAX with CUDA from source on my Windows laptop. I have installed MSYS2. I am following the instructions given here. However, I am unable to install realpath using pacman -S patch realpath as mentioned in the docs. I am getting…
Siladittya
  • 1,156
  • 2
  • 13
  • 41
0
votes
1 answer

JAX and training neural networks

I am a beginner in JAX and I am trying to learn how to train a neural network. I saw some blogs, but as I understand it, there isn't a library that lets you train one easily, like 'fit' in sklearn. I am interested in a classification task; could you…
user15479632
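
A minimal sketch of a hand-written training loop in plain JAX for binary classification, meant only to show the moving parts that a `fit` method would hide. The model (a single linear layer), the toy data, and names such as `train_step` and `LEARNING_RATE` are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value


def loss_fn(params, x, y):
    logits = x @ params["w"] + params["b"]
    # Binary cross-entropy written directly on logits for stability:
    # log(1 + exp(z)) - y * z
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)


@jax.jit
def train_step(params, x, y):
    # jax.grad differentiates the loss with respect to the params pytree.
    grads = jax.grad(loss_fn)(params, x, y)
    # Plain gradient descent applied leaf by leaf.
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


x = jnp.array([[0.0], [1.0], [2.0], [3.0]])  # toy inputs
y = jnp.array([0.0, 0.0, 1.0, 1.0])          # toy labels
params = {"w": jnp.zeros(1), "b": jnp.zeros(())}

initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params = train_step(params, x, y)
final_loss = loss_fn(params, x, y)
```

Libraries built on JAX package the same pattern behind layers and optimizers, but the loop itself usually stays explicit.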
0
votes
1 answer

Profiling JAX code: What is redzone_checker and why does it take so much time?

I have found this post but am still unclear on what the redzone_checker kernel is doing and why. Specifically, should it be taking > 90% of my application's runtime? TensorBoard reports that it is taking the vast majority of the runtime of my JAX…
emprice
  • 912
  • 11
  • 21
0
votes
1 answer

Reimplementing bert-style pooler throws shape error as if length-dimension were still needed

I have trained an off-the-shelf Transformer(). Now I want to use the encoder in order to build a classifier. For that I want to only use the first token's output (bert-style cls-token-result) and run that through a dense layer. What I…
Phillip Bock
  • 1,879
  • 14
  • 23