I have come across some behaviour I don't understand in Jax when trying to do an SVD compression for large arrays. Here is the sample code:
import jax.numpy as jnp
import jax.scipy as jsc
from jax import jit

@jit
def jax_compress(L):
    U, S, _ = jsc.linalg.svd(L,
                             full_matrices=False,
                             lapack_driver='gesvd',
                             check_finite=False,
                             overwrite_a=True)
    maxS = jnp.max(S)
    chi = jnp.sum(S / maxS > 1E-1)
    return chi, jnp.asarray(U)
Jax/jit give an enormous performance increase over SciPy when considering this snippet of code, but ultimately I want to reduce the dimensionality of U, which I do by wrapping it in the function:
def jax_process(A):
    chi, U = jax_compress(A)
    return U[:, 0:chi]
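For reference, this is roughly how I benchmark the two steps (the matrix shape and the first warm-up call are placeholders of my own; the definitions are repeated so the snippet runs standalone, and I dropped the extra SVD keyword arguments to keep it minimal):

```python
import time

import numpy as np
import jax.numpy as jnp
import jax.scipy as jsc
from jax import jit


@jit
def jax_compress(L):
    # SVD-based compression: count singular values above 10% of the largest
    U, S, _ = jsc.linalg.svd(L, full_matrices=False)
    chi = jnp.sum(S / jnp.max(S) > 1E-1)
    return chi, U


def jax_process(A):
    # slice U down to the chi dominant columns
    chi, U = jax_compress(A)
    return U[:, 0:chi]


A = jnp.asarray(np.random.rand(200, 200))
jax_process(A)  # warm-up call so jit compilation is not included in the timing

t0 = time.perf_counter()
chi, U = jax_compress(A)
t1 = time.perf_counter()
V = jax_process(A)
t2 = time.perf_counter()
print(f"compress: {t1 - t0:.4f}s, process: {t2 - t1:.4f}s")
```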
This step is unbelievably costly in terms of computation time, far more so than the SciPy equivalent, as a timing comparison shows. sc_compress and sc_process are the SciPy equivalents of the JAX code above. Slicing the arrays costs almost nothing in SciPy, but it is very expensive when applied to the output of a jit-compiled function. Does anyone have some insight into this behaviour?