How can I generate random numbers between 0 and 1 in JAX? Basically, I am looking to replicate the following NumPy call in JAX:

import numpy as np
np.random.random(1000)  # 1000 floats in [0, 1)
Bunny Rabbit

The equivalent in JAX would be:

from jax import random
key = random.PRNGKey(758493)  # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))  # 1000 floats in [0, 1)

For more information, see the documentation of the jax.random module.
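As a quick sanity check, here is a minimal sketch (reusing the seed from above; the variable names are illustrative). It shows that `random.uniform` defaults to the half-open interval [0, 1), that passing `minval`/`maxval` explicitly with the same key reproduces the same values, and that the default dtype is float32 unless 64-bit mode is enabled, whereas `np.random.random` returns float64:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(758493)

# Defaults are minval=0.0, maxval=1.0, so samples lie in [0, 1).
# The dtype is float32 unless JAX's 64-bit mode is enabled.
samples = random.uniform(key, shape=(1000,))

# Passing the bounds explicitly with the same key gives identical values.
explicit = random.uniform(key, shape=(1000,), minval=0.0, maxval=1.0)

# If you need float64 like NumPy, enable x64 first, e.g.
# jax.config.update("jax_enable_x64", True), then pass dtype=jnp.float64.
```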

Also be aware that JAX's random number generator does not maintain any global state: the output of each call is determined entirely by the key you pass in, so to get new random numbers you must explicitly split the key rather than just calling the function again as you would in NumPy. For more background on this, see JAX Sharp Bits: Random Numbers.
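A minimal sketch of what the lack of global state means in practice, assuming the same seed as above (the shapes and variable names are arbitrary): calling `random.uniform` twice with the same key returns the same numbers, and `random.split` is used to derive fresh subkeys, each of which should be used only once.

```python
from jax import random

key = random.PRNGKey(758493)

# Reusing a key gives the same numbers every time: nothing advances
# behind the scenes between calls.
a = random.uniform(key, shape=(3,))
b = random.uniform(key, shape=(3,))

# To get fresh numbers, split the key and consume the new subkey.
key, subkey = random.split(key)
c = random.uniform(subkey, shape=(3,))

# random.split can also produce several subkeys at once.
keys = random.split(key, 4)
batches = [random.uniform(k, shape=(3,)) for k in keys]
```

Keeping the returned `key` and splitting it again each time you need more randomness is the usual pattern, since it avoids accidentally reusing a key.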

jakevdp