I have the NumPy function shown below, which I'm trying to speed up with JAX, but for some reason the JAX version is slower.
Could someone point out what I can do to improve its performance? I suspect the list comprehension that builds Cg_new is the culprit, but breaking it apart doesn't yield any further performance gains in JAX.
import numpy as np
def testFunction_numpy(C, Mi, C_new, Mi_new):
Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = np.zeros((1, len(Mi[0])))
invertCsensor_new = np.linalg.inv(C_new)
Wg_new = np.dot(invertCsensor_new, Mi_new)
Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
# Benchmark inputs for testFunction_numpy. C is never used by the function,
# and only Mi's column count (len(Mi[0]) == 8) is read from Mi.
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
# C_new is inverted inside the function; assumed non-singular.
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
# IPython/Jupyter magic: not valid syntax in a plain .py file.
%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
# Reported result: 1000 loops, best of 3: 1.73 ms per loop
Here's the JAX equivalent:
import jax.numpy as jnp
import numpy as np
import jax
def testFunction_JAX(C, Mi, C_new, Mi_new):
Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = jnp.zeros((1, len(Mi[0])))
invertCsensor_new = jnp.linalg.inv(C_new)
Wg_new = jnp.dot(invertCsensor_new, Mi_new)
Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)
jitter = jax.jit(testFunction_JAX)
%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop