I'm fiddling around with JAX, and I get two different results depending only on whether I use the `jax.jit` decorator:
```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats

def jitless_log_likelihood(x, mu, sigma):
    # Sum of the log of the multivariate normal density over the rows of x
    return jnp.sum(jnp.log(jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma)))

@jax.jit
def log_likelihood(x, mu, sigma):
    # Same computation, compiled with jit
    return jitless_log_likelihood(x, mu, sigma)

key = jax.random.PRNGKey(0)
M = 10000
x = jax.random.normal(key, (10, M))  # 10 samples of dimension M
mu = jnp.array([0] * M)
sigma = jnp.identity(M)

print(jitless_log_likelihood(x, mu, sigma))
print(log_likelihood(x, mu, sigma))
```
On my CPU, the code above prints:

```
-141839.1
-inf
```
Why is this happening?
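One way to narrow this down, sketched under the assumption that the difference involves floating-point underflow: with M = 10000, each row's log-density is roughly -0.5·M·ln(2π) - 0.5·‖x‖² ≈ -14000, and exp of that is far below the smallest positive float32, so a pdf value that has been rounded to 0.0 gives -inf under `jnp.log`. Staying in log space with `jstats.multivariate_normal.logpdf` avoids the exp/log round trip entirely. The sketch below uses M = 1000 (an arbitrary choice to keep the dense covariance cheap), and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats

def logpdf_log_likelihood(x, mu, sigma):
    # Hypothetical variant: sum log-densities directly, no exp followed by log
    return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))

jit_logpdf_log_likelihood = jax.jit(logpdf_log_likelihood)

key = jax.random.PRNGKey(0)
M = 1000  # smaller than in the question so the identity covariance stays small
x = jax.random.normal(key, (10, M))
mu = jnp.zeros(M)
sigma = jnp.identity(M)

# Each density is about exp(-1400), which underflows to 0.0 in float32
pdf_values = jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma)
eager_result = logpdf_log_likelihood(x, mu, sigma)
jitted_result = jit_logpdf_log_likelihood(x, mu, sigma)

print(pdf_values)
print(eager_result)
print(jitted_result)
```

If the pdf values print as zeros while both logpdf results are finite and agree, that points at the pdf/log round trip rather than at `jit` itself.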