
I need to compute the (log of the) determinant of the Gram matrix of a matrix J, i.e. log det(J J^T), and I was wondering whether there is a way to compute this efficiently and stably in NumPy/SciPy.

import numpy as np
m, n = 100, 150
J = np.random.randn(m, n)
np.log(np.linalg.det(J.dot(J.T)))

Is there some LAPACK routine or math trick I could use to speed things up and make it more stable?

Euler_Salter

1 Answer


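For better numerical stability, I would suggest using np.linalg.slogdet, which returns the sign and the log of the absolute value of the determinant without computing the determinant itself (which can overflow). The log-determinant is what you are after anyway. There may also be a very small gain from using np.inner(J, J) instead of J.dot(J.T). To really speed things up, I would recommend using jax.numpy.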
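To illustrate the stability point, here is a minimal NumPy-only sketch. The seed, the scale factor 10 and the helper names naive_logdet/stable_logdet are assumptions made for the example. Scaling J by 10 multiplies det(J J^T) by 10^200, which pushes it past the float64 maximum (about 1.8e308), so the naive version returns inf while slogdet still gives a finite log-determinant.

```python
import numpy as np

# Fixed seed so the example is reproducible (the original uses an unseeded RNG).
rng = np.random.default_rng(0)
m, n = 100, 150
J = rng.standard_normal((m, n))

def naive_logdet(J):
    # Forms det(J J^T) explicitly and then takes the log: the determinant
    # itself can overflow float64 even though its log is a modest number.
    with np.errstate(over="ignore"):
        return np.log(np.linalg.det(J @ J.T))

def stable_logdet(J):
    # slogdet returns (sign, log|det|) without ever forming det itself.
    sign, logdet = np.linalg.slogdet(J @ J.T)
    return logdet

print(naive_logdet(J), stable_logdet(J))          # both finite and equal here
J_big = 10.0 * J                                  # adds 200*log(10) ≈ 460.5 to the log-det
print(naive_logdet(J_big), stable_logdet(J_big))  # naive gives inf, slogdet stays finite
```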

import numpy as np
import jax
import jax.numpy as jnp

m, n = 100, 150
J = np.random.randn(m, n)

def a(J):
    return np.log(np.linalg.det(J.dot(J.T)))

def b(J):
    return np.linalg.slogdet(np.inner(J, J))[1]

def c(J):
    return jnp.linalg.slogdet(jnp.inner(J, J))[1]

# jit + compile
d = jax.jit(c)
d(J)

# check correctness
print(np.allclose(a(J), b(J))) # True
print(np.allclose(a(J), c(J))) # True
print(np.allclose(a(J), d(J))) # True
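On the "math trick" side of the question: J J^T is symmetric positive definite whenever J has full row rank (which holds with probability 1 for a Gaussian J with m <= n), so a few other factorizations give the log-determinant directly. The sketch below assumes full row rank, and its helper names are illustrative. It shows a Cholesky factorization (which exploits the symmetry), a QR factorization of J^T (which never forms J J^T and so avoids squaring the condition number of J), and the singular values of J. These variants are not covered by the run times in this answer, so any speed difference would need to be measured.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 100, 150
J = rng.standard_normal((m, n))  # assumed full row rank (m <= n)

def logdet_cholesky(J):
    # G = J J^T = L L^T with L lower triangular, so log det G = 2 * sum(log(diag(L))).
    L = np.linalg.cholesky(J @ J.T)
    return 2.0 * np.sum(np.log(np.diag(L)))

def logdet_qr(J):
    # With the reduced QR J^T = Q R, G = R^T Q^T Q R = R^T R,
    # so log det G = 2 * sum(log|diag(R)|); G itself is never formed.
    R = np.linalg.qr(J.T, mode="r")
    return 2.0 * np.sum(np.log(np.abs(np.diag(R))))

def logdet_svd(J):
    # The eigenvalues of J J^T are the squared singular values of J.
    s = np.linalg.svd(J, compute_uv=False)
    return 2.0 * np.sum(np.log(s))

reference = np.linalg.slogdet(J @ J.T)[1]
for f in (logdet_cholesky, logdet_qr, logdet_svd):
    print(f.__name__, np.isclose(f(J), reference))
```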

Run times on Google Colab:

%timeit -n 1000 -r 10 a(J)
# 240 µs ± 16.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)

%timeit -n 1000 -r 10 b(J)
# 227 µs ± 10.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)

J_dev = jax.device_put(J)

%timeit -n 1000 -r 10 c(J_dev).block_until_ready()
# 112 µs ± 4.46 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)

%timeit -n 1000 -r 10 d(J_dev).block_until_ready()
# 96.2 µs ± 4.23 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)

So a speedup of roughly 2x (240 µs down to 112 µs with jax.numpy, or 96 µs when jitted) is possible this way.
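%timeit is an IPython/Colab magic. In a plain Python script, the standard-library timeit module gives a rough equivalent. Here is a sketch for the two NumPy-only variants; the seeded generator is an assumption, and no timings are shown because they depend on the machine.

```python
import timeit
import numpy as np

m, n = 100, 150
J = np.random.default_rng(0).standard_normal((m, n))

def a(J):
    return np.log(np.linalg.det(J.dot(J.T)))

def b(J):
    return np.linalg.slogdet(np.inner(J, J))[1]

# Rough equivalent of `%timeit -n 1000 -r 10`: 10 repeats of 1000 calls each,
# reported as mean ± std. dev. of the per-call time across the repeats.
for f in (a, b):
    runs = timeit.repeat(lambda: f(J), number=1000, repeat=10)
    per_call_us = np.array(runs) / 1000 * 1e6
    print(f"{f.__name__}: {per_call_us.mean():.1f} µs ± {per_call_us.std():.1f} µs per loop")
```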

Mercury