Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations. Further optimization comes from automatic vectorization and from running code on GPUs/TPUs.

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax
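
As a quick illustration of the transformations described above, here is a minimal sketch using the core JAX APIs (jax.grad, jax.jit, jax.vmap):

import jax
import jax.numpy as jnp

def f(x):
    # An ordinary NumPy-style function.
    return jnp.sum(jnp.tanh(x) ** 2)

grad_f = jax.grad(f)                  # automatic differentiation
fast_f = jax.jit(f)                   # XLA compilation for CPU/GPU/TPU
batched_grad = jax.vmap(jax.grad(f))  # automatic vectorization over a batch

x = jnp.ones(3)
print(fast_f(x))                             # compiled forward pass
print(grad_f(x))                             # gradient of f at x, shape (3,)
print(batched_grad(jnp.ones((4, 3))).shape)  # (4, 3)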

496 questions
0
votes
0 answers

TFLite didn't find op for builtin opcode 'SELECT' version '3'

I have a TFLite model converted from Jax:

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],
    [[("encoder_inputs", encoder_inputs), ("decoder_inputs", decoder_inputs), ("primings",…
hlidka
  • 2,086
  • 1
  • 15
  • 14
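
Not the asker's full code, but a sketch of the conversion path the excerpt describes; the model function and inputs below are stand-ins, and enabling SELECT_TF_OPS as a fallback is a common workaround (not a confirmed fix here) when the TFLite runtime lacks a builtin op version:

import numpy as np
import jax.numpy as jnp
import tensorflow as tf

def serving_func(encoder_inputs, decoder_inputs):
    # Stand-in for the asker's model; just an illustrative computation.
    return jnp.matmul(encoder_inputs, jnp.transpose(decoder_inputs))

encoder_inputs = np.ones((2, 4), dtype=np.float32)
decoder_inputs = np.ones((3, 4), dtype=np.float32)

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],
    [[("encoder_inputs", encoder_inputs), ("decoder_inputs", decoder_inputs)]],
)
# Allow full TF ops as a fallback for builtin ops the runtime
# does not support (here, SELECT version 3).
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()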
0
votes
2 answers

How to edit a TensorFlow dataset?

I imported the CIFAR10 dataset via tensorflow_dataset.load(). This gives me…
user541396
  • 163
  • 1
  • 9
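
A minimal sketch of the standard way to transform examples after loading: tfds.load returns a tf.data.Dataset, which is edited via map rather than in place.

import tensorflow as tf
import tensorflow_datasets as tfds

# Load CIFAR10 as a tf.data.Dataset of {'image', 'label'} dicts.
ds = tfds.load("cifar10", split="train")

def normalize(example):
    # Edit each example, e.g. scale images to [0, 1].
    image = tf.cast(example["image"], tf.float32) / 255.0
    return {"image": image, "label": example["label"]}

ds = ds.map(normalize).batch(32)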
0
votes
0 answers

Flax implementation of padding_idx from torch.nn.Embedding

I have been rewriting some of my PyTorch models in jax/flax and came across the issue of converting torch.nn.Embedding to flax.linen.Embed. There does not appear to be a direct translation for PyTorch's padding_idx. The keyword essentially zeroes the…
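
There is no padding_idx keyword in flax.linen.Embed; one common workaround (a sketch, not a drop-in equivalent) is to zero the output vectors at padding positions after lookup. pad_id below is an assumed padding index:

import jax
import jax.numpy as jnp
import flax.linen as nn

class PaddedEmbed(nn.Module):
    num_embeddings: int
    features: int
    pad_id: int = 0  # assumed padding index, analogous to torch's padding_idx

    @nn.compact
    def __call__(self, ids):
        emb = nn.Embed(self.num_embeddings, self.features)(ids)
        # Zero the outputs at padding positions; unlike torch, this does
        # not pin the embedding table row itself to zero.
        mask = (ids != self.pad_id)[..., None]
        return emb * mask

ids = jnp.array([[1, 2, 0]])
model = PaddedEmbed(num_embeddings=10, features=4)
params = model.init(jax.random.PRNGKey(0), ids)
out = model.apply(params, ids)  # rows where ids == pad_id are zero vectors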
0
votes
0 answers

NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: unhandled cuda error

I am running run_t5_mlm_flax.py with 8 GPUs, but I get this error (it works with only one GPU): NCCL operation ncclAllReduce(send_buffer, recv_buffer, element_count, dtype, reduce_op, comm, gpu_stream) failed: unhandled cuda error. Do you have a…
Antoine23
  • 79
  • 1
  • 5
0
votes
0 answers

Vanishing parameters in MAML JAX (Meta Learning)

I am working on an implementation of MAML (see https://arxiv.org/pdf/1703.03400.pdf) in Jax. When training on a distribution of simple linear regression tasks, it seems to perform fine (it takes a while to converge but ultimately works). However, when…
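
A minimal sketch of the MAML inner/outer structure in JAX (generic, not the asker's code): the inner adaptation step is itself differentiated by the outer jax.grad.

import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Simple linear-regression loss, matching the task family in the question.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def inner_adapt(params, x, y, lr=0.01):
    # One gradient step on the support set (the MAML inner loop).
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

def maml_loss(params, x_s, y_s, x_q, y_q):
    # Outer objective: query loss after inner adaptation.
    adapted = inner_adapt(params, x_s, y_s)
    return loss(adapted, x_q, y_q)

meta_grads = jax.grad(maml_loss)  # differentiates through the inner step

params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
x_s, y_s = jnp.ones((10, 2)), jnp.ones(10)
g = meta_grads(params, x_s, y_s, x_s, y_s)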
0
votes
1 answer

Irregular/Inhomogeneous Arrays with JAX

What is the recommended approach to implement array behaviour/methods on irregular/inhomogeneous data (which possesses some inherent dimensionality) within JAX? Two principal options come to mind: make the data homogeneous and use a mask; flatten and implement…
DavidJ
  • 326
  • 2
  • 10
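
A sketch of the first option from the excerpt (pad to a homogeneous array and carry a mask), since JAX transformations require rectangular shapes:

import jax
import jax.numpy as jnp

# Ragged data: rows of different lengths.
rows = [[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]
max_len = max(len(r) for r in rows)

# Pad to a rectangular array and record which entries are real.
padded = jnp.array([r + [0.0] * (max_len - len(r)) for r in rows])
mask = jnp.array([[1.0] * len(r) + [0.0] * (max_len - len(r)) for r in rows])

def masked_mean(x, m):
    # Reductions must ignore the padding.
    return jnp.sum(x * m) / jnp.sum(m)

row_means = jax.vmap(masked_mean)(padded, mask)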
0
votes
0 answers

INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:70: NCCL operation

I am trying to run something using JAX (it works with only 1 GPU), but when I increase the GPU count to 4 (32 CPUs), I get this error: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:70: NCCL operation…
Antoine23
  • 79
  • 1
  • 5
0
votes
0 answers

`vmap`ping a top level function when there are multiple nested functions that require `vmap`ping over different arguments

I am trying to implement a paper called conditionally-structured Gaussian variational inference (CS-GVA). Firstly, when implementing variational inference algorithms using reparameterization, you express the model parameters theta as a deterministic…
hasco641
  • 69
  • 5
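
A generic sketch of the pattern in the title: each level of nesting gets its own vmap, with in_axes selecting which argument is mapped at that level (the names here are illustrative, not taken from CS-GVA):

import jax
import jax.numpy as jnp

def inner(theta, eps):
    # Reparameterization: theta is a parameter pytree, eps one noise draw.
    return theta["mu"] + theta["sigma"] * eps

# vmap over noise draws only; theta is broadcast (in_axes=None).
over_eps = jax.vmap(inner, in_axes=(None, 0))

# vmap over parameter sets only; the eps batch is broadcast.
over_theta = jax.vmap(over_eps, in_axes=(0, None))

thetas = {"mu": jnp.zeros((4, 3)), "sigma": jnp.ones((4, 3))}
eps = jnp.linspace(-1.0, 1.0, 5)
print(over_theta(thetas, eps).shape)  # (4, 5, 3)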
0
votes
0 answers

"Message passing" in Jax, and interplay with asynchronous dispatch?

I have a feed-forward neural network which is basically a composition of N functions. I want to pipeline the training procedure of said network in a multi-device environment by executing some of these functions on one device, forwarding the result…
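
One basic building block (a sketch, not a full pipeline): jax.device_put moves an intermediate result to another device, and JAX's asynchronous dispatch means these calls enqueue work without blocking Python.

import jax
import jax.numpy as jnp

@jax.jit
def stage(x):
    # Stand-in for one of the N composed functions.
    return jnp.tanh(x)

x = jnp.ones((32, 128))
for d in jax.devices():
    # Move the current activation to device d and run the next stage there.
    # Dispatch is asynchronous: each call enqueues work and returns
    # immediately; block_until_ready() forces completion.
    x = stage(jax.device_put(x, d))
x.block_until_ready()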
0
votes
1 answer

Execute Markov chains with tree-structured state in parallel with JAX

I have a Markov chain function implemented in JAX that advances the chain from state s -> s' based on some training data (X_train).

def step(state: dict, key, X_train) -> dict:
    new_state = advance(state, key, X_train)
    return new_state

Here,…
Hylke
  • 75
  • 6
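
A sketch of running several chains in parallel with jax.vmap over the dict-valued (pytree) state; step below is a toy stand-in for the asker's step/advance functions:

import jax
import jax.numpy as jnp

def step(state, key):
    # Toy stand-in for advance(state, key, X_train): a random-walk update.
    noise = jax.random.normal(key, state["pos"].shape)
    return {"pos": state["pos"] + noise}

n_chains = 8
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
# Stack per-chain states into one pytree with a leading chain axis.
states = {"pos": jnp.zeros((n_chains, 3))}

# vmap maps over the leading axis of both the state pytree and the keys.
states = jax.vmap(step)(states, keys)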
0
votes
1 answer

How to use JAX vmap to efficiently calculate importance sampling estimate

I have code to calculate the off-policy importance sampling estimate commonly used in reinforcement learning. It is not important to know what that is, but readers who do may find this question a little easier to follow. Basically,…
marvin
  • 581
  • 5
  • 9
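
A generic sketch of an importance-sampling estimate vectorized with vmap (per-trajectory weights and returns; the names and shapes are illustrative, not the asker's):

import jax
import jax.numpy as jnp

def trajectory_estimate(logp_target, logp_behavior, rewards):
    # Per-trajectory importance weight prod_t pi(a_t|s_t) / mu(a_t|s_t),
    # computed in log space for stability, times the trajectory return.
    log_w = jnp.sum(logp_target - logp_behavior)
    return jnp.exp(log_w) * jnp.sum(rewards)

# vmap over a batch of trajectories (leading axis), then average.
batched = jax.vmap(trajectory_estimate)

logp_t = jnp.zeros((100, 50))  # 100 trajectories, 50 steps each
logp_b = jnp.zeros((100, 50))
rews = jnp.ones((100, 50))

is_estimate = jnp.mean(batched(logp_t, logp_b, rews))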
0
votes
0 answers

How to change the batch size for a neural network in JAX

Using flax to create a network:

def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return…
MichaelMMeskhi
  • 659
  • 8
  • 26
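
For the question above: the shape passed to cnn.init is only a dummy used to infer parameter shapes, so for a typical CNN (no batch-dependent parameters) the same params work with any batch size at apply time. A sketch with a hypothetical CNN:

import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten per example
        return nn.Dense(features=10)(x)

cnn = CNN()
# Init with a dummy batch of 1; parameter shapes don't depend on it.
params = cnn.init(jax.random.PRNGKey(0), jnp.ones([1, 28, 28, 1]))["params"]

# Apply with any batch size at train/eval time.
out = cnn.apply({"params": params}, jnp.ones([64, 28, 28, 1]))
print(out.shape)  # (64, 10)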
0
votes
1 answer

Example of running a ray.rllib model in a JAX environment?

I am trying to train a DQN agent in an environment coded in JAX, but the initialization of the trainer fails when it first tries to reset the environment (with a "not a valid JAX type" error). Before getting into the debugging process, I thought of…
blindeyes
  • 409
  • 3
  • 13
0
votes
0 answers

What can cause "LLVM ERROR: Trying to register different dialects for the same namespace"?

I'm using JAX on Ubuntu 18.04. Specifically, I'm using a new JAX backend implemented outside the XLA source tree. The code builds successfully, but I get a runtime error "LLVM ERROR: Trying to register different dialects for the…
ipe_zyz
  • 3
  • 1
0
votes
0 answers

Where is the batch dimension with jax.pmap?

I was using jax.pmap to start my training on 4 1080Ti GPUs. I got the training step function step_fn = get_train_step() and then I used pmap:

p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn), …
Noel Tong
  • 1
  • 1
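
For the last question: jax.pmap maps over the leading axis of its inputs, so the global batch is usually reshaped to [num_devices, per_device_batch, ...] before the call. A sketch, independent of the asker's lax.scan wrapper:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def train_step(x):
    # Stand-in for the real training step.
    return x * 2.0

p_step = jax.pmap(train_step)

global_batch = jnp.ones((n_dev * 32, 28, 28, 1))
# pmap consumes the leading axis as the device axis.
sharded = global_batch.reshape((n_dev, 32) + global_batch.shape[1:])
out = p_step(sharded)  # out.shape == (n_dev, 32, 28, 28, 1)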