How can I generate random numbers between 0 and 1 in JAX?
Basically, I am looking to replicate the following NumPy function in JAX:
np.random.random(1000)
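The equivalent in JAX is jax.random.uniform, which, like np.random.random, samples floats from the half-open interval [0, 1) by default: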
from jax import random
key = random.PRNGKey(758493) # Random seed is explicit in JAX
random.uniform(key, shape=(1000,))
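A minimal sketch of what the call above returns, assuming JAX's default 32-bit mode (x64 not enabled). It shows the default bounds and that calling uniform twice with the same key gives identical numbers, because no hidden state advances between calls:

```python
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(758493)  # explicit seed, same as in the answer above

# minval=0.0 and maxval=1.0 are the defaults, so this samples from [0, 1)
x = random.uniform(key, shape=(1000,))

# Reusing the same key reproduces exactly the same numbers
y = random.uniform(key, shape=(1000,))

print(x.shape, x.dtype)
print(bool(jnp.all(x == y)))
print(bool(jnp.all((x >= 0.0) & (x < 1.0))))
```

If you need a different range, uniform also accepts minval and maxval keyword arguments.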
For more information, see the documentation of the jax.random module.
Also be aware that JAX's random number generator does not maintain any sort of global state, so you'll need to think about it a bit differently than you may be accustomed to in NumPy. For more background on this, see JAX Sharp Bits: Random Numbers.
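A minimal sketch of the usual pattern for getting a fresh batch on every call, assuming you want behaviour like calling np.random.random(1000) repeatedly. Instead of relying on global state, you split the key yourself and hand each sampling call its own subkey (the names subkey1 and subkey2 are just illustrative):

```python
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(758493)

# Split off a new subkey for each draw; keep the other half for later splits
key, subkey1 = random.split(key)
a = random.uniform(subkey1, shape=(1000,))

key, subkey2 = random.split(key)
b = random.uniform(subkey2, shape=(1000,))

# Different subkeys give different samples, unlike reusing one key
print(bool(jnp.array_equal(a, b)))
```

The rule of thumb is to never pass the same key to two sampling calls unless you deliberately want identical numbers.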