I have a function that I want to accelerate using Numba. It computes the sum of the log-likelihoods of residuals given their variance-covariance matrices (this is only for context and is not important for the question):
    @jit(nopython=True)
    def log_ll_norm_multivar(sigma, epsilon, mean=None) -> float:
        """
        This function computes the log-likelihood of a multivariate normal law applied to t observations of n size
        Args:
            sigma : the variance-covariance matrix, at each t, or constant. Must be ndarray(n,n) or ndarray(t,n,n)
                If it is (n,n), it will be copied at all times to have a (t,n,n)
            epsilon : ndarray(t, n) residuals
        Returns:
            float : Sum of the log likelihood of the residual, given the sigma variance-covariance matrices
        """
        t_max, n = epsilon.shape
        if sigma.shape == (n, n):
            sigma = np.array([sigma for _ in range(0, t_max)])
        if sigma.shape != (t_max, n, n):
            raise IllegalParameterException("Sigma shape must be t*n*n")
        if mean is None:
            mean = np.zeros((t_max, n))
        if mean.shape != (t_max, n):
            raise Exception("If provided, mean must be of shape (T,n)")
        epsilon_centered = epsilon - mean
        sum_det_sigma = np.sum(np.log(np.linalg.det(sigma)))
        inv_sigma = inv(sigma)
        third_term = np.array([
            epsilon_centered[t].transpose()
            .dot(inv_sigma[t])
            .dot(epsilon_centered[t])
            for t in range(0, t_max)
        ]).sum()
        return -1 / 2 * (t_max * n * log(np.pi * 2) + sum_det_sigma + third_term)
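For context, the value returned on the last line corresponds to the formula below. The notation is introduced here only for illustration and is not part of the code: T stands for `t_max`, Σ_t for `sigma[t]`, ε_t for `epsilon[t]` and μ_t for `mean[t]`.

```latex
\ell = -\frac{1}{2}\left( T\,n\,\log(2\pi)
       + \sum_{t=1}^{T} \log\det\Sigma_t
       + \sum_{t=1}^{T} (\varepsilon_t-\mu_t)^{\top}\,\Sigma_t^{-1}\,(\varepsilon_t-\mu_t) \right)
```

The middle sum is `sum_det_sigma` and the last sum is `third_term`.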
Numba fails to "compile" the line sum_det_sigma = np.sum(np.log(np.linalg.det(sigma))) and says:
    numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    No implementation of function Function(<function sum at 0x10d1dfdc0>) found for signature:

     >>> sum(float64)

    There are 2 candidate implementations:
      - Of which 2 did not match due to:
      Overload of function 'sum': File: numba/core/typing/npydecl.py: Line 379.
        With argument(s): '(float64)':
       No match.
    ....
    def log_ll_norm_multivar(sigma, epsilon, mean=None) -> float:
        <source elided>
        np.sum(np.log(np.linalg.det(sigma)))
When I debug the code without @jit(nopython=True), np.log(np.linalg.det(sigma)) turns out to be an np.array of shape (1000,). So this is not the situation described in the post "Numba nopython error with np.sum", where np.sum was applied to a scalar.
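The shape check in plain NumPy (no Numba involved) can be sketched roughly as follows. The sizes `t_max = 1000` and `n = 3` and the identity matrices are placeholder assumptions for illustration, not my real data.

```python
import numpy as np

# Placeholder sizes and data (assumptions for illustration only)
t_max, n = 1000, 3
sigma = np.array([np.eye(n) for _ in range(t_max)])  # stacked matrices, shape (t_max, n, n)

# Plain NumPy applies det to each matrix along the leading axis,
# so the result has one determinant per time step.
log_det = np.log(np.linalg.det(sigma))
print(log_det.shape)  # (1000,)

# np.sum therefore receives a 1-D array here, not a scalar
sum_det_sigma = np.sum(log_det)
```

This only shows what the interpreter does without the decorator. It says nothing about the types Numba infers for the same expression in nopython mode.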
Just to be sure, I tried this code:
    @jit(nopython=True)
    def test():
        arr_log = np.log(np.ones((1000,), dtype=np.float64))
        return arr_log.sum()
And it works perfectly.
What is going on with my code?