Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further speedups come from automatic vectorization and from running code on GPUs and TPUs.
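Below is a minimal sketch of how these transformations compose. It uses only the public `jax.grad`, `jax.vmap`, and `jax.jit` functions. The function `loss` and the input values are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function written with the NumPy-like jax.numpy API.
def loss(w, x):
    return jnp.sum(jnp.tanh(w * x) ** 2)

# grad: derivative of loss with respect to its first argument (w).
dloss_dw = jax.grad(loss)

# vmap: evaluate that gradient for a whole batch of x rows without a Python loop.
# in_axes=(None, 0) keeps w shared and maps over axis 0 of x.
batched_grad = jax.vmap(dloss_dw, in_axes=(None, 0))

# jit: compile the composed function with XLA (CPU, GPU or TPU backend).
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))
print(fast_batched_grad(w, xs).shape)  # (4, 3)
```

Each transformation returns an ordinary Python function, so they can be stacked freely; here the gradient is vectorized first and the result is then compiled.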

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
3 votes, 1 answer

Jaxlib pip installation failure

From the command line, I've tried following this installation tutorial. I'd like to avoid building from source if at all possible. Currently, I'm not sure what the issue is. Could anyone verify that they get the same/different response when trying to…
jbuddy_13
3 votes, 2 answers

Performance drop when slicing jax.numpy arrays

I have come across some behaviour I don't understand in Jax when trying to do an SVD compression for large arrays. Here is the sample code: @jit def jax_compress(L): U, S, _ = jsc.linalg.svd(L, full_matrices = False, lapack_driver =…
Jiles
3 votes, 1 answer

Nontransitive subclassing with numpy and jax

My question is simple: >>> isinstance(x, jax.numpy.ndarray) True >>> issubclass(jax.numpy.ndarray, numpy.ndarray) True >>> isinstance(x, numpy.ndarray) False ? And now I will ramble so SE will accept my reasonable question.
Archaick
3 votes, 1 answer

How do I save an optimizer state of JAX trained model?

I am playing with the mnist_vae example and can't figure out how to properly save/load weights of the trained model. enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2)) _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 *…
egorssed
3 votes, 1 answer

Conditional update in JAX?

In autograd/numpy I could do: q[q<0] = 0.0 How can I do the same thing in JAX? I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.
Andriy Drozdyuk
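For context on the question above: JAX arrays are immutable, so in-place assignment such as `q[q<0] = 0.0` is not supported, and creating the array with plain NumPy does not change that once it is converted to a JAX array. A minimal sketch of the functional alternative, assuming the goal is simply to replace negative entries with 0.0, uses `jnp.where` to build a new array. The name `clip_negative` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(q):
    # Hypothetical helper. JAX arrays cannot be mutated in place, so instead of
    # q[q < 0] = 0.0 we build a new array: 0.0 where q < 0, otherwise q's value.
    return jnp.where(q < 0, 0.0, q)

q = jnp.array([-1.5, 0.0, 2.0, -0.1])
q_clipped = clip_negative(q)  # negative entries become 0.0; q itself is unchanged
print(q_clipped)
```

`jnp.where` with a boolean mask keeps the output shape fixed, which is why it also works under `jax.jit`. For updates at known integer indices, `q.at[i].set(v)` follows the same style and returns an updated copy rather than modifying `q`.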
3 votes, 1 answer

GET and POST API calls in Java with basic authentication

I want to call GET and POST APIs in Java without using any framework. I need to use basic authentication. Can anybody help me with a tutorial link? On Google I only found code using the Spring framework, but I am not using Spring. I am looking for code…
Shruti sharma
3 votes, 1 answer

Google JAX 1D convolutional neural network

I'm trying to implement a 1D convolutional neural network in Google Jax with stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv). I have a 1 dimensional input array with 18 and output array with…
salutnomo
3 votes, 1 answer

Find gradient of a function: Sympy vs. Jax

I have a function Black_Cox() which calls other functions as shown below: import numpy as np from scipy import stats # Parameters D = 100 r = 0.05 γ = 0.1 # Normal CDF N = lambda x: stats.norm.cdf(x) H = lambda V, T, L, σ: np.exp(-r*T) * N(…
Sandu Ursu
2 votes, 1 answer

Is it possible to obtain values from JAX traced arrays with DynamicJaxprTrace level larger than 1 using any of the callback functions?

So I have a program that has multiple functions, each with its own JAX calls, and here is the main function: @partial(jax.jit, static_argnames=("numberOfVoxels",)) def process_valid_voxels(numberOfVoxels, voxelPositions, voxelLikelihoods, ps, t, M,…
kalinka227
2 votes, 1 answer

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed

Description: The output: 2023-07-31 01:53:45.016563: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or…
TushanJing
2 votes, 2 answers

Optimization and speed enhancement of function using JAX

The code performs Gaussian blurring on the image intensityRefracted2DF using a Gaussian kernel centered at each pixel. The Gaussian kernel is determined by the values in the darkField array, where each value represents the standard deviation (sigma)…
2 votes, 1 answer

How to make a function a valid jax type?

When I pass an object created using the following function into a jax.lax.scan function: def logdensity_create(model, centeredness = None, varname = None): if centeredness is not None: model = reparam(model, config={varname:…
imk
2 votes, 1 answer

Creating a jax array using existing jax arrays of different lengths throws error

I am using the following code to set a particular row of a jax 2D array to a particular value using jax arrays: zeros_array = jnp.zeros((3, 8)) value = jnp.array([1,2,3,4]) value_2 = jnp.array([1]) value_3 = jnp.array([1,2]) values =…
imk
2 votes, 1 answer

JAX vmap vs pmap vs Python multiprocessing

I am rewriting some code from pure Python to JAX. I have gotten to the point where in my old code, I was using Python's multiprocessing module to parallelize the evaluation of a function over all of the CPU cores in a single node as follows: # start…
2 votes, 1 answer

Issues while using JAX to minimize the Lennard-Jones potential for two points and the force (gradient of the potential)--result doesn't match

I am trying to use the minimization function in JAX to find the distance of two points satisfying the Lennard-Jones potential E = 2(1/r^4-1/r^2), and I can successfully get the result: [-0.20710678 1.20710678], for which r = 1.41, as expected. However, next I…
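As a worked check of the setup in the last question: with E(r) = 2(1/r^4 - 1/r^2), the derivative is dE/dr = 2(-4/r^5 + 2/r^3), which vanishes when r^2 = 2, i.e. r = sqrt(2) ≈ 1.414, matching the reported separation. A minimal sketch, assuming the two points live on a line and the force is the negative gradient of E with respect to their positions, computes that force with `jax.grad`. The names `energy`, `dE_dr`, and `force` are hypothetical, and this does not reproduce the asker's own code or diagnose the mismatch they describe.

```python
import jax
import jax.numpy as jnp

def energy(x):
    # Hypothetical setup: x holds the 1D positions of the two points.
    r = jnp.abs(x[1] - x[0])
    return 2.0 * (1.0 / r**4 - 1.0 / r**2)

def dE_dr(r):
    # Analytic derivative of E(r) = 2(1/r^4 - 1/r^2), for comparison.
    return 2.0 * (-4.0 / r**5 + 2.0 / r**3)

# Force on each point = negative gradient of the energy w.r.t. its position.
force = jax.grad(lambda x: -energy(x))

x_min = jnp.array([-0.20710678, 1.20710678])  # minimum reported in the question
x_test = jnp.array([0.0, 1.0])                # r = 1, inside the repulsive region

print(energy(x_min))   # close to -0.5, since E(sqrt(2)) = 2(1/4 - 1/2)
print(force(x_min))    # close to zero at the minimum
print(force(x_test))   # equal and opposite forces pushing the points apart
```

At r = 1 the analytic derivative is dE/dr = -4, so the point on the right feels a force of +4 and the point on the left -4, which is what the autodiff result should show.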