Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs/TPUs.
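The sketch below shows what composing these transformations looks like in practice. The function `loss` and the arrays are illustrative assumptions and are not part of the tag wiki:

```python
import jax
import jax.numpy as jnp

# An ordinary NumPy-style function of parameters w and one input x.
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# Transformations compose: differentiate w.r.t. w, vectorize over a batch of x, then JIT-compile.
grad_loss = jax.grad(loss)                                  # d loss / d w for a single x
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))       # w shared, x batched along axis 0
fast_batched_grad = jax.jit(batched_grad)                   # compiled with XLA (CPU/GPU/TPU)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
g = fast_batched_grad(w, xs)  # one gradient per row of xs, shape (3, 2)
```

Each row of `g` equals `2 * (x @ w) * x` for the corresponding `x`.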

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
0 votes, 1 answer

How to vmap over specific function in jax?

I have this function, which works for a single vector: def vec_to_board(vector, player, dim, reverse=False): player_board = np.zeros(dim * dim) player_pos = np.argwhere(vector == player) if not reverse: …
Bigyan Karki
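A minimal sketch of vmapping such a function, assuming a simplified, hypothetical `vec_to_board_sketch` that only marks the player's squares (the question's `reverse` branch is elided and not reproduced). The main points are to replace the data-dependent `argwhere` with a fixed-shape `jnp.where`, and to close over the non-batched arguments:

```python
import jax
import jax.numpy as jnp

def vec_to_board_sketch(vector, player, dim):
    # Hypothetical simplified version of the question's function:
    # 1 where the square belongs to `player`, 0 elsewhere.
    # jnp.where keeps the output shape fixed; argwhere's output size depends
    # on the data, which vmap/jit cannot trace without a static size.
    board = jnp.where(vector == player, 1, 0)
    return board.reshape(dim, dim)

dim, player = 3, 1
vectors = jnp.array([
    [1, 0, 2, 0, 1, 0, 0, 0, 2],
    [0, 0, 0, 1, 1, 1, 2, 2, 0],
])

# Close over player/dim with a lambda and map over the leading axis of `vectors`.
boards = jax.vmap(lambda v: vec_to_board_sketch(v, player, dim))(vectors)  # shape (2, 3, 3)
```

A lambda is used instead of `functools.partial` with keyword arguments because `jax.vmap` maps keyword arguments along axis 0.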
0 votes, 2 answers

ERROR: No matching distribution found for jaxlib==0.1.67

I need jaxlib==0.1.67 for a project I'm working on, but I can't downgrade. At the moment I have jaxlib==0.1.75 and my program keeps failing due to an error I can't find a solution to either. I compared all versions of the important packages to…
0 votes, 1 answer

How can I initialize the hidden state (carry) of a (flax linen) GRUCell as a learnable parameter (e.g. using model.init)?

I create a GRU model in Jax using Flax and I initialize the model parameters using model.init as follows: import jax.numpy as np from jax import random import flax.linen as nn from jax.nn import initializers class RNN(nn.Module): n_RNN_units:…
Jabby
0 votes, 2 answers

How to install objax with GPU support?

I have followed the objax documentation to install the library with GPU support: https://objax.readthedocs.io/en/stable/installation_setup.html i.e. pip install --upgrade objax CUDA_VERSION=11.6 pip install -f…
Victor
0 votes, 1 answer

I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values - workaround needed please

I am new to Jax. I am implementing a variational autoencoder (VAE) using Jax and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using compositions of flax.linen.nn modules). Crucially,…
Jabby
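The question's code is elided. Assuming the goal is to store a sampled latent code into an array during training, one common workaround is to keep the buffer as a JAX array and use a functional `.at[...].set(...)` update instead of assigning into a NumPy array. `write_sample` is a hypothetical helper for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_sample(buffer, idx, z):
    # Writing a traced value into a NumPy array (np_buf[idx] = z) fails under jit,
    # because NumPy needs a concrete value. JAX arrays are immutable, so .at[].set()
    # returns a new array with the row replaced, and it works on tracers.
    return buffer.at[idx].set(z)

buffer = jnp.zeros((4, 2))
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (2,))  # stand-in for a latent code sampled from the encoder
buffer = write_sample(buffer, 1, z)
```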
0 votes, 1 answer

Struggling to understand nested vmaps in JAX

I just about understand unnested vmaps, but try as I may, and I have tried my darnedest, nested vmaps continue to elude me. Take the snippet from this text, for example. I don't understand what the axes are in this case. Is the nested vmap(kernel,…
Olumide
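The snippet the question refers to is elided. As a sketch of how nested vmaps map axes, assume a hypothetical RBF `kernel` on two single vectors. The inner vmap maps over one argument and the outer vmap maps over the other, which together produce a full Gram matrix:

```python
import jax
import jax.numpy as jnp

def kernel(x, y):
    # Hypothetical RBF kernel between two single vectors -> scalar.
    return jnp.exp(-jnp.sum((x - y) ** 2))

xs = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # shape (3, 2)
ys = jnp.array([[0.0, 0.0], [1.0, 1.0]])              # shape (2, 2)

# Inner vmap: x is not mapped (None), y is mapped over rows of ys -> one row of the matrix.
row = jax.vmap(kernel, in_axes=(None, 0))
# Outer vmap: map over rows of xs, pass ys whole (None) -> matrix of shape (3, 2).
gram = jax.vmap(row, in_axes=(0, None))(xs, ys)
```

Here `gram[i, j] == kernel(xs[i], ys[j])`. Each `in_axes` entry refers only to the arguments of the function being wrapped at that level.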
0 votes, 1 answer

Mixed partial derivative w.r.t. tensor in Pytorch

Question: Is there any working method to calculate the gradient of a (non-scalar) tensor function? Example: given n-by-n symmetric matrices X, Y and the matrix function Z(X, Y) = torch.mm(X.mm(X), Y), calculate d(dZ/dX)/dY. Expected answer: d(dZ/dX)/dY =…
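In JAX, one way to get this mixed derivative is to nest Jacobian transformations, with `argnums` selecting the variable at each step. This is a sketch with small, arbitrarily chosen symmetric matrices. The result is the full tensor of partials, with shape `(n,)*6`, rather than a closed-form matrix expression:

```python
import jax
import jax.numpy as jnp

def Z(X, Y):
    # The question's Z(X, Y) = X X Y, written with jnp instead of torch.
    return X @ X @ Y

X = jnp.array([[1.0, 2.0], [2.0, 3.0]])  # symmetric, n = 2
Y = jnp.array([[0.5, 1.0], [1.0, 4.0]])  # symmetric, n = 2

# dZ/dX as a Jacobian with shape (n, n, n, n): index [i, j, k, l] = dZ_ij / dX_kl.
dZ_dX = jax.jacrev(Z, argnums=0)
# Differentiate that Jacobian w.r.t. Y: index [i, j, k, l, p, q] = d^2 Z_ij / (dX_kl dY_pq).
d2Z_dXdY = jax.jacfwd(dZ_dX, argnums=1)(X, Y)
```

Because Z is linear in Y, this tensor does not depend on the value of Y.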
0 votes, 1 answer

How to reorder different sets of parameters in dm-haiku

In dm-haiku, parameters of neural networks are defined in dictionaries where keys are module (and submodule) names. If you would like to traverse through the values, there are multiple ways of doing so as shown in this dm-haiku issue. However, the…
VdZ
0 votes, 0 answers

How does Deepmind's Haiku keep track of layers?

I am looking at Deepmind's implementation of a transformer using the Haiku neural network library. I'm confused by their forward function: def build_forward_fn(vocab_size: int, d_model: int, num_heads: int, num_layers: int,…
Foobar
0 votes, 1 answer

How to quantize pre-trained JAX model to TfLite model using tf.lite?

I have a pre-trained JAX model for MAXIM: Image Enhancement. Now to reduce the runtime and use it in production, I'll have to quantize the weights. I have 2 options since there is no direct conversion to ONNX. JAX -> Tensorflow -> ONNX (Help…
Deshwal
0 votes, 1 answer

Can you update parameters of a module from inside the nn.compact of that module? (self modifying networks)

I'm quite new to flax and I was wondering what the correct way is to get this behavior: param = f.init(key,x) new_param, y = f.apply(param,x), where f is an nn.Module instance. f might go through multiple operations to get new_param and that…
hal9000
0 votes, 1 answer

Is XLA's reshape cheap?

I want to know the performance characteristics of xla::Reshape. Specifically, I can imagine that it could be implemented by simply remapping XlaOp metadata e.g. addresses, rather than creating a whole new XlaOp. Alternatively, does XLA's fusion or…
joel
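The question is about the C++ `xla::Reshape`. From the JAX side, one way to see what XLA does with a reshape is to dump the program before and after XLA's optimization passes. This is a sketch with an illustrative function `f`. Whether the reshape survives, turns into a layout-only bitcast, or gets fused into a neighbouring kernel depends on the backend and the layouts involved, so the code shows how to check rather than answering the performance question:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Reshape followed by elementwise ops; XLA is free to fuse these.
    return jnp.sin(x.reshape(4, 3)) * 2.0

x = jnp.arange(12.0)
lowered = jax.jit(f).lower(x)   # program as emitted by JAX, before XLA optimization
compiled = lowered.compile()    # program after XLA optimization for the current backend
before = lowered.as_text()
after = compiled.as_text()
print(after)  # look for how (or whether) the reshape appears in the optimized HLO
```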
0 votes, 0 answers

Error when trying to build JAX in Windows

I am trying to install JAX in Windows, using the steps detailed in this answer. But I am getting an error when trying to build JAX. Here is what I have tried to do: I have downloaded JAX, installed Bazel and msys2 using choco, restarted…
John Moody
0 votes, 1 answer

Calculate only lower triangular elements of a matrix OR calculation on all possible pairs of the elements of a vector with jax

Is it possible to efficiently run some calculation on all possible pairs of the elements of a vector? I.e. I want to fill the lower triangular elements of a matrix (possibly flattened). I.e. I want to: calculate do_my_calculation(input_vector[i],…
ARF
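One approach is to build only the lower-triangular index pairs with `jnp.tril_indices` and evaluate the calculation on those pairs with `vmap`, giving the flattened lower triangle directly. `do_my_calculation` below is a hypothetical stand-in for the question's function, written for a single pair of scalars:

```python
import jax
import jax.numpy as jnp

def do_my_calculation(a, b):
    # Hypothetical per-pair calculation on two scalars.
    return a * b + a - b

v = jnp.array([1.0, 2.0, 3.0, 4.0])
n = v.shape[0]

# (i, j) pairs with j < i: the strictly lower triangle (k=0 would include the diagonal).
rows, cols = jnp.tril_indices(n, k=-1)
# Evaluate only those n*(n-1)/2 pairs -> flattened lower triangle, in row-major order.
flat = jax.vmap(do_my_calculation)(v[rows], v[cols])
# Optionally scatter back into a dense (n, n) matrix, zeros elsewhere.
dense = jnp.zeros((n, n)).at[rows, cols].set(flat)
```

This avoids computing the upper triangle at all. The trade-off is the gather through `rows`/`cols`, which for cheap calculations may cost about as much as evaluating the full `n x n` grid.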
0 votes, 1 answer

For each element, loop over all previous elements

I have a 2D JAX array containing an image. For each pixel P[y, x] of the image, I would like to loop over all pixels P[y, x-i] to the left of that pixel and reduce those to a single value. The exact reduction computation involves finding a…
sk29910
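The exact reduction in the question is truncated. If it happens to be associative (max is assumed here purely for illustration), the per-pixel loop over all pixels to the left can be replaced by a parallel prefix scan along the row axis with `jax.lax.associative_scan`. The scan is inclusive, so the sketch shifts it by one column to cover only strictly previous pixels:

```python
import jax
import jax.numpy as jnp

img = jnp.array([[3.0, 1.0, 4.0, 1.0],
                 [5.0, 9.0, 2.0, 6.0]])

# Inclusive prefix reduction along each row: out[y, x] = max(img[y, 0..x]).
inclusive = jax.lax.associative_scan(jnp.maximum, img, axis=1)

# Exclusive version: only pixels strictly to the left, with -inf (max's identity) at x = 0.
exclusive = jnp.concatenate(
    [jnp.full((img.shape[0], 1), -jnp.inf), inclusive[:, :-1]], axis=1
)
```

A non-associative reduction cannot use this trick and would need something like `jax.lax.scan` carrying per-row state instead.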