Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides a NumPy-compatible, native-Python interface built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs and TPUs.
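Here is a minimal sketch of how these transformations compose. The function `f` and the sample points are illustrative and do not come from any question below. `jax.grad` builds the derivative, `jax.vmap` vectorizes it over a batch, and `jax.jit` compiles the result.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A scalar function of a scalar input: f(x) = x^2 * sin(x)
    return x ** 2 * jnp.sin(x)

# jax.grad returns a new function that computes df/dx
df = jax.grad(f)

# Transformations compose: vmap vectorizes df over a batch, jit compiles the result with XLA
batched_df = jax.jit(jax.vmap(df))

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_df(xs))  # should match 2x*sin(x) + x^2*cos(x) elementwise
```

Because each transformation returns an ordinary Python function, the same pattern extends to gradients with respect to several arguments (`argnums`) or to nested transformations.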

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions

0 votes, 1 answer

Is there a way to accept a function while taking the gradient using jax.grad?

I am trying to build a neural-network-based solver for the differential equation y' + 2xy = 0. import jax.numpy as jnp import jax import matplotlib.pyplot as plt from tqdm import tqdm import numpy as np def softplus(x): …
JS4137
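A common way to set up this kind of solver is sketched below. Several parts are assumptions made for illustration: the trial form y(x) = 1 + x·N(x), which builds in an initial condition y(0) = 1; the tiny softplus network; and all parameter shapes. The idea is to take `jax.grad` of the trial solution with respect to x (via `argnums=1`), form the residual y' + 2xy at collocation points, and then differentiate the resulting loss with respect to the parameters.

```python
import jax
import jax.numpy as jnp

def network(params, x):
    # Hypothetical one-hidden-layer MLP mapping a scalar x to a scalar
    w1, b1, w2, b2 = params
    h = jax.nn.softplus(w1 * x + b1)
    return jnp.dot(w2, h) + b2

def trial_solution(params, x):
    # Assumed initial condition y(0) = 1 is built into the trial form
    return 1.0 + x * network(params, x)

# dy/dx: differentiate the trial solution with respect to x (argument index 1)
dy_dx = jax.grad(trial_solution, argnums=1)

def loss(params, xs):
    # Mean squared residual of y' + 2xy = 0 over the collocation points
    def residual(x):
        return dy_dx(params, x) + 2.0 * x * trial_solution(params, x)
    return jnp.mean(jax.vmap(residual)(xs) ** 2)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
hidden = 10
params = (jax.random.normal(k1, (hidden,)), jnp.zeros(hidden),
          jax.random.normal(k2, (hidden,)) * 0.1, 0.0)
xs = jnp.linspace(0.0, 2.0, 50)

# Gradient of the loss with respect to the network parameters (argument index 0)
grads = jax.grad(loss)(params, xs)
```

Under the assumed y(0) = 1, the exact solution is y = exp(-x²), which gives a convenient check once the parameters have been trained.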

0 votes, 0 answers

Loading FlaxHybridCLIP trained model

System info: transformers version 4.27.4; platform Linux-5.15.0-52-generic-x86_64-with-glibc2.29; Python version 3.8.10; huggingface_hub version 0.13.4; PyTorch version (GPU?) 1.9.0+cpu (False); TensorFlow version (GPU?) 2.9.1 (True); Flax version…
khaled

0 votes, 0 answers

Converting a JAX function with multiple arguments to a TensorFlow Keras layer fails

This works perfectly: def f_jax(x): return jnp.sin(jnp.cos(x)) f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(batch, _)"]) f_tf = tf.function(f_tf, autograph=False) f_tf = f_tf.get_concrete_function( tf.TensorSpec(shape=(None, 2),…
galah92

0 votes, 0 answers

Error when trying to plot the posterior curves of a lightweight_mmm model

I'm working on a lightweight marketing mix model, but I get this error when I try to execute: plot.plot_response_curves(media_mix_model=mmm, target_scaler=target_scaler) Cannot interpret value of type

0 votes, 1 answer

Colab notebook crashes on 22 March 2023, although it worked on 19 January 2023

I have a problem I do not know how to tackle. I have a Colab notebook that worked perfectly on CPU. Here is the working notebook (19 Jan 2023). Then I tried to get a new notebook working under the same conditions (CPU). Dealing with JAX version 0.3.25…
Jean-Eric

0 votes, 1 answer

Official Flax Example "Imagenet" is Bugged

I have tried to run the official Flax/ImageNet code. Because of JAX version differences, I tried two approaches: in the first, I downgraded jax and jaxlib to 0.3.25; in the second, I changed jax and jaxlib to 0.4.4. Then I set up…
RanWang

0 votes, 0 answers

Fitting an MLP in JAX gives different results than TensorFlow

I need to build an MLP in JAX, but I get slightly different (and, in my opinion, inaccurate) results from JAX compared with an MLP created in TensorFlow. In both cases I created a dataset where y is a linear function of X plus standard Gaussian noise,…
fabianod
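For reference, here is a minimal full-batch MLP regression in JAX on data of the kind described, where y is linear in X plus standard Gaussian noise. The layer sizes, tanh activations, Glorot-style normal initialization, plain SGD, learning rate, and the helper name `init_mlp` are all assumptions for illustration. Keras defaults (for example, its initializers and whichever optimizer is configured) may differ from these choices, so aligning them is worth checking when comparing the two frameworks.

```python
import jax
import jax.numpy as jnp

LR = 0.01  # assumed learning rate for plain SGD

def init_mlp(key, sizes):
    # Hypothetical helper: Glorot-style normal weights and zero biases per layer
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((w, jnp.zeros(n_out)))
    return params

def mlp(params, x):
    # tanh hidden layers followed by a linear output layer for regression
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def mse(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # One full-batch gradient descent step over the whole parameter pytree
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

key = jax.random.PRNGKey(0)
k_x, k_noise, k_init = jax.random.split(key, 3)
x = jax.random.normal(k_x, (256, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_w + jax.random.normal(k_noise, (256, 1))  # linear signal plus N(0, 1) noise

params = init_mlp(k_init, [3, 16, 1])
initial_loss = mse(params, x, y)
for _ in range(500):
    params = sgd_step(params, x, y)
final_loss = mse(params, x, y)
```

With unit-variance noise, the training MSE cannot be expected to go much below about 1, so a gap of that size relative to the noise-free targets is not by itself a sign of a bug.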

0 votes, 0 answers

Is there any way to create the xla_client with CPU and GPU?

I'm using this code as a baseline, but with this xla_client I can only use GPU resources. I want to use CPU swap (unified) memory as a GPU resource. How can I create an xla_client that supports CPU memory? Is there any…

0 votes, 0 answers

Is there a way to convert a custom checkpoint from TensorFlow to JAX/Flax/Optax?

I have been searching for an answer to this for hours. In Hugging Face Transformers, if we don't modify the model, we can load it directly into JAX/Flax/Optax. However, what if I want to train a TensorFlow model utilizing its TPU…
RanWang

0 votes, 0 answers

Speeding up a function that inspects neighbors in NumPy arrays

I am trying to write a function that returns the coordinates of array points satisfying certain conditions. Specifically, I am trying to find singularities: points that have at least one neighbor in each possible state (rest, recovery, excited). This function…
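Below is a vectorized alternative to looping over cells, sketched in JAX; the same shift-and-compare idea works with plain NumPy. It assumes the three states are encoded as the integers 0, 1, and 2, and that "neighbors" means the 8-cell Moore neighborhood. Both of these, and the helper name `singularity_mask`, are assumptions rather than details from the question.

```python
import jax
import jax.numpy as jnp

REST, RECOVERY, EXCITED = 0, 1, 2  # assumed integer encoding of the three states

@jax.jit
def singularity_mask(grid):
    # Pad with -1 so border cells never see a spurious state in the padding
    padded = jnp.pad(grid, 1, constant_values=-1)
    h, w = grid.shape
    # Offsets of the 8 Moore neighbours (the cell itself is excluded)
    shifts = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1) if (di, dj) != (0, 0)]
    # Stack shifted views: neighbours[k, i, j] is the k-th neighbour of cell (i, j)
    neighbours = jnp.stack(
        [padded[1 + di:1 + di + h, 1 + dj:1 + dj + w] for di, dj in shifts]
    )

    def has(state):
        # True where at least one neighbour is in the given state
        return jnp.any(neighbours == state, axis=0)

    return has(REST) & has(RECOVERY) & has(EXCITED)

grid = jnp.array([[0, 1, 2],
                  [0, 0, 0],
                  [0, 0, 0]])
mask = singularity_mask(grid)
# argwhere has a data-dependent output shape, so it is called outside jit
coords = jnp.argwhere(mask)
```

The mask itself is computed by the jitted function. Only the final conversion to coordinates, whose length depends on the data, runs eagerly.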

0 votes, 2 answers

Nested vmap in pmap - JAX

I can currently run simulations in parallel on one GPU using vmap. To speed things up, I want to batch the simulations over multiple GPU devices using pmap. However, when I pmap the vmapped function, I get a tracing error. The code I use to get a…
Anton B
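The usual shape convention for combining the two transformations is sketched below with a placeholder simulation. The `simulate` function, its update rule, and `per_device` are hypothetical. `jax.pmap` splits the leading axis across devices, so that axis should match (and cannot exceed) the number of available devices. The next axis is the batch that `jax.vmap` maps over on each device. Using `jax.local_device_count()` lets the same code run on a single CPU.

```python
import jax
import jax.numpy as jnp

def simulate(x0):
    # Hypothetical stand-in for one simulation: 100 steps of a simple update rule
    def step(x, _):
        x = x - 0.1 * x ** 3
        return x, x
    final, _ = jax.lax.scan(step, x0, None, length=100)
    return final

n_devices = jax.local_device_count()
per_device = 8  # hypothetical number of simulations per device

# Axis 0 -> devices (pmap), axis 1 -> simulations on each device (vmap)
x0 = jnp.linspace(-1.0, 1.0, n_devices * per_device).reshape(n_devices, per_device)

parallel_simulate = jax.pmap(jax.vmap(simulate))
results = parallel_simulate(x0)  # shape (n_devices, per_device)
```

Reshaping the flat batch to `(n_devices, per_device)` before the call, and flattening the result afterwards, keeps `simulate` itself written for a single trajectory.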

0 votes, 0 answers

'python setup.py bdist_wheel' did not run successfully (error: invalid command 'bdist_wheel') when building a conda env and installing JAX

I run this process inside a new conda environment with Python 3.8. Installing Python environment management utils with: #!/bin/bash brew install pyenv || brew upgrade pyenv b'pyenv install 3.11.1\n~/.pyenv/versions/3.11.1/bin/python3.11 -m pip…
Sivan D

0 votes, 1 answer

Defining the correct vectorization axes for JAX vmap with arrays of different shapes and sizes

Following the answer to this post, I defined the following function, 'f_switch', which dynamically switches between multiple functions based on an index array (using 'jax.lax.switch'): import jax from jax import vmap; import jax.random as…
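Here is a minimal sketch of vectorizing `jax.lax.switch` over an index array. The three branch functions and the `scale` argument are placeholders; only the name `f_switch` comes from the question. Each entry of `in_axes` says which axis of the corresponding argument is mapped, and `None` broadcasts an argument unchanged. This is how inputs of different shapes can be combined in one `vmap` call. All branches must return outputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

# Placeholder branches: each takes (x, scale) and returns an array of the same shape
branches = [
    lambda x, scale: scale * jnp.sin(x),
    lambda x, scale: scale * jnp.cos(x),
    lambda x, scale: scale * x ** 2,
]

def f_switch(i, x, scale):
    # Applies branches[i] to (x, scale)
    return jax.lax.switch(i, branches, x, scale)

idx = jnp.array([0, 1, 2, 1])         # one branch index per element
xs = jnp.array([0.5, 1.0, 2.0, 3.0])  # one input per element
scale = 2.0                           # shared by every element

# Map over idx and xs together (axis 0 of each); broadcast scale with None
out = jax.vmap(f_switch, in_axes=(0, 0, None))(idx, xs, scale)
```

When the index is batched, `vmap` effectively evaluates every branch and selects per element. Branches should therefore be cheap and safe to evaluate on all inputs.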

0 votes, 0 answers

Training a machine learning model with JAX + ObJAX raises 'ValueError: Unable to cast Python instance to C++ type'

When I trained a WRN model with the JAX + ObJAX framework, I got an error: 'ValueError: Unable to cast Python instance to C++ type (compile in debug mode for details)'. I don't know why. Error information: Traceback (most recent call last): File…
Doren

0 votes, 0 answers

Converting a Flax model to PyTorch

I have several image classifiers in Flax. For one of the models I saved the state, and for the other two I saved the parameters as a FrozenDict with a .flax extension. How could I convert the whole models to PyTorch and use these…
m0ss