Questions tagged [jax]

JAX lets you write automatically differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations such as differentiation, vectorization, and just-in-time compilation. Further speedups come from automatic vectorization and from running code on GPUs and TPUs.
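A minimal sketch of what "composable function transformations" means in practice. `jax.grad`, `jax.vmap`, and `jax.jit` are real JAX APIs. The `loss` function and the input values are made up for illustration:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar function: squared dot product of x and w
    return (x @ w) ** 2

# grad differentiates loss with respect to its first argument (w)
grad_loss = jax.grad(loss)

# vmap maps grad_loss over the leading axis of x (w is not batched),
# so no explicit Python loop is needed
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit compiles the composed function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0]])
result = fast_batched_grad(w, xs)
print(result)  # one gradient row per input row
```

Each transformation returns an ordinary Python function. That is why they can be stacked in any order.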

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
-2 votes, 1 answer

Repeating rows from array

I have a problem because I would like to repeat all rows of an array of shape (X, Y) n times, without using loops, to get an array of shape (n*X, Y). import jax.numpy as jnp arr = jnp.array([[12, 14, 12, 0, 1], [0, 14, 12, 0, 1], [0, 0, 12, 0,…
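A loop-free sketch for this kind of question, using the real `jnp.repeat` and `jnp.tile` functions. The question's array is truncated, so a small hypothetical `arr` stands in for it. The wording does not say which row order is wanted, so both are shown:

```python
import jax.numpy as jnp

# Hypothetical input of shape (X, Y) = (2, 3); not the asker's actual data
arr = jnp.array([[1, 2, 3],
                 [4, 5, 6]])
n = 3

# Each row repeated n times in a row: r0, r0, r0, r1, r1, r1
repeated = jnp.repeat(arr, n, axis=0)

# The whole array stacked n times: r0, r1, r0, r1, r0, r1
tiled = jnp.tile(arr, (n, 1))

print(repeated.shape)  # (6, 3)
print(tiled.shape)     # (6, 3)
```

Both calls produce an array of shape (n*X, Y). Keeping `n` a plain Python int keeps the output shape static, which matters if the code is later wrapped in `jax.jit`.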