I'm fiddling around with JAX, and I came across two different results just by adding the jit decorator:

import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats


def jitless_log_likelihood(x, mu, sigma):
    return jnp.sum(jnp.log(jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma)))

@jax.jit
def log_likelihood(x, mu, sigma):
    return jitless_log_likelihood(x, mu, sigma)


key = jax.random.PRNGKey(0)

M = 10000

x = jax.random.normal(key, (10,M))
mu = jnp.array([0]*M)
sigma = jnp.identity(M)


print(jitless_log_likelihood(x, mu, sigma))
print(log_likelihood(x, mu, sigma))

On my CPU, the code above yields:

-inf
-141839.1

Why is this happening?

ABaron

1 Answer

In general, JIT compilation will rearrange, combine, or elide operations in your function for efficiency, and this can sometimes change its numerical results. For details and more explanation, see FAQ: jit changes the exact numerics of outputs. In this case, the fact that you're computing jnp.log of an exponentiated quantity (multivariate_normal.pdf) is likely the culprit.

Here's a simpler way to see the same behavior:

from jax import jit
import jax.numpy as jnp

def f(x):
  return jnp.log(jnp.exp(x))

x = -141839.1

print(f(x))  # -inf
print(jit(f)(x))  # -141839.1

In the JIT-compiled function, the compiler notices that exp and log cancel out, and so elides those operations, avoiding the underflow (exp returning 0, then log(0) giving -inf) that happens when you compute them step by step.
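
If you want to verify this on your machine, one option (assuming a recent JAX version that exposes the ahead-of-time lowering API, i.e. jit(f).lower(...).compile()) is to print the optimized HLO that XLA produces; after the algebraic simplifier runs, the exp/log pair should be gone:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.log(jnp.exp(x))

# Lower and compile f for a scalar argument, then print the HLO that
# XLA actually executes; the exp/log round trip should not appear.
print(jax.jit(f).lower(1.0).compile().as_text())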

You can achieve a better-behaved version of your function by avoiding taking the log of the exponential in the first place, using the logpdf function built for this purpose:

def jitless_log_likelihood(x, mu, sigma):
    return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))
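
With this fix, no underflow occurs in the first place, so there is nothing for the compiler to optimize away and both versions should agree. A quick check, reusing the variables and the jitted wrapper from the question:

@jax.jit
def log_likelihood(x, mu, sigma):
    return jitless_log_likelihood(x, mu, sigma)

# Both now compute sum(logpdf) directly, with no exp/log round trip,
# and should print the same finite log-likelihood.
print(jitless_log_likelihood(x, mu, sigma))
print(log_likelihood(x, mu, sigma))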
jakevdp