This is a repost of my solution from the Numba Discourse forum: https://numba.discourse.group/t/call-scipy-splev-routine-in-numba-jitted-function/1122/7.
I originally followed @max9111's suggestion of using `objmode`, which gave a temporary fix. However, since the code was performance-critical, I ended up writing a Numba version of SciPy's `interpolate.splev` routine for spline interpolation.
import numpy as np
import numba
from scipy import interpolate
import matplotlib.pyplot as plt
import time
# Custom wrap of scipy's splrep
def custom_splrep(x, y, k=3, *, rtol=1e-5):
    """
    Fit an interpolating B-spline with scipy's splrep and flag equispaced data.

    Parameters
    ----------
    x, y : array_like
        Data points; ``x`` must satisfy splrep's requirements (strictly
        increasing, more than ``k`` points).
    k : int, optional
        Spline degree (cubic by default).
    rtol : float, optional
        Maximum allowed deviation of ``x`` from a uniform grid, as a
        fraction of the grid spacing, for ``x`` to count as equispaced.

    Returns
    -------
    tuple
        ``(t, c, k, equi_spaced, dx)``: knots, coefficients and degree from
        ``splrep``, a bool telling whether ``x`` lies on a uniform grid, and
        the spacing of that grid.

    Raises
    ------
    Whatever ``interpolate.splrep`` raises for invalid input.
    """
    x = np.asarray(x, dtype=float)
    # Fit first so that invalid input (e.g. too few points) is reported by
    # splrep itself rather than as an IndexError in the spacing check below.
    t, c, k = interpolate.splrep(x, y, k=k)
    # Spacing of the uniform grid through the end points; less sensitive to
    # rounding than any single difference such as x[1] - x[0].
    dx = (x[-1] - x[0]) / (x.size - 1)
    # Use a tolerance relative to dx. A fixed absolute rounding (e.g. to 5
    # decimals) declares any finely spaced data "equispaced", because all
    # of its differences round to zero.
    grid = x[0] + dx * np.arange(x.size)
    equi_spaced = bool(np.max(np.abs(x - grid)) <= rtol * dx)
    return (t, c, k, equi_spaced, dx)
# Numba accelerated implementation of scipy's splev
@numba.njit(cache=True)
def numba_splev(x, coeff):
    """
    Evaluate a B-spline at the points ``x`` (port of FITPACK's splev/fpbspl).

    Parameters
    ----------
    x : 1-D float array
        Points at which the spline is evaluated.
    coeff : tuple
        ``(t, c, k, equi_spaced, dx)`` as returned by ``custom_splrep``:
        knots, coefficients, degree, equispacing flag and grid spacing.

    Returns
    -------
    1-D float array with one spline value per entry of ``x``.

    Points outside the base interval are extrapolated from the end spans.
    When ``equi_spaced`` is set, the knot interval is first guessed in O(1)
    from ``dx``. The same local search used for general knots then corrects
    the guess, so an inexact guess costs time but never accuracy.
    """
    t, c, k, equi_spaced, dx = coeff
    n = t.size
    m = x.size
    k1 = k + 1
    k2 = k1 + 1
    nk1 = n - k1
    # FITPACK 1-based convention: arg lies in t(l) <= arg < t(l+1), i.e.
    # t[l-1] <= arg < t[l] here, with l restricted to [k1, nk1]; l1 == l+1.
    # l is kept between points, so sorted inputs need only short searches.
    l = k1
    l1 = l + 1
    y = np.zeros(m)
    # Work arrays for the k+1 B-splines that are non-zero at arg.
    h = np.zeros(k1)
    hh = np.zeros(k1)
    for i in range(m):
        # fetch a new x-value arg
        arg = x[i]
        if equi_spaced and dx > 0:
            # Guess the interval, assuming the interior knots t[k1], t[k1+1],
            # ... are spaced by dx (as splrep's interpolating fit produces on
            # an equispaced grid). The interval starting at t[k1] + q*dx has
            # the 1-based index k1 + 1 + q.
            guess = k1 + 1 + np.floor((arg - t[k1]) / dx)
            # Clamp in floating point (also catches NaN) before int().
            if not guess >= k1:
                guess = float(k1)
            elif guess > nk1:
                guess = float(nk1)
            l = int(guess)
            l1 = l + 1
        # Local search (FITPACK splev): move left, then right, until
        # t[l-1] <= arg < t[l1-1] or an end span is reached.
        while not ((arg >= t[l - 1]) or (l1 == k2)):
            l1 = l
            l = l - 1
        while not ((arg < t[l1 - 1]) or (l == nk1)):
            l = l1
            l1 = l + 1
        # Evaluate the non-zero B-splines at arg (FITPACK fpbspl recurrence).
        # Every h[0..j+1] is rewritten in pass j, so no per-point reset is
        # needed.
        h[0] = 1.0
        for j in range(k):
            for ll in range(j + 1):
                hh[ll] = h[ll]
            h[0] = 0.0
            for ll in range(j + 1):
                li = l + ll
                lj = li - j - 1
                if t[li] != t[lj]:
                    f = hh[ll] / (t[li] - t[lj])
                    h[ll] += f * (t[li] - arg)
                    h[ll + 1] = f * (arg - t[lj])
                else:
                    # Zero-length knot span: this term vanishes. FITPACK
                    # moves on to the next term here rather than stopping.
                    h[ll + 1] = 0.0
        # Combine with the k+1 coefficients c[l-k1 .. l-1].
        sp = 0.0
        for j in range(k1):
            sp += c[l - k1 + j] * h[j]
        y[i] = sp
    return y
######################### Testing and comparison #############################
def _time_calls(func, args, repeats=10000):
    """Return the elapsed seconds for ``repeats`` calls of ``func(*args)``."""
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(repeats):
        func(*args)
    return time.perf_counter() - start


# Generate a data set for interpolation
x = np.linspace(10, 100, 200)
y = np.sin(x)
# Calculate the cubic spline coefficients
coeff_1 = interpolate.splrep(x, y, k=3)  # scipy's splrep
coeff_2 = custom_splrep(x, y, k=3)  # Custom wrap of scipy's splrep
# Generate data for interpolation and randomize
x2 = np.linspace(0, 110, 10000)
np.random.shuffle(x2)
# Interpolate. The first numba_splev call also JIT-compiles the function,
# which keeps compilation time out of the timed loops below.
y2 = interpolate.splev(x2, coeff_1)  # scipy's splev
y3 = numba_splev(x2, coeff_2)  # Numba accelerated implementation of scipy's splev
# Report the agreement numerically, not only visually.
print("max |scipy splev - numba splev| =", np.max(np.abs(y2 - y3)))
# Plot data
plt.plot(x, y, '--', linewidth=1.0, color='green', label='data')
plt.plot(x2, y2, 'o', color='blue', markersize=2.0, label='scipy splev')
plt.plot(x2, y3, '.', color='red', markersize=1.0, label='numba splev')
plt.legend()
plt.show()

print("\nTime for random interpolations")
print("scipy splev", _time_calls(interpolate.splev, (x2, coeff_1)))
print("numba splev", _time_calls(numba_splev, (x2, coeff_2)))

print("\nTime for non random interpolations")
# Generate data for interpolation without randomize
x2 = np.linspace(0, 110, 10000)
print("scipy splev", _time_calls(interpolate.splev, (x2, coeff_1)))
print("numba splev", _time_calls(numba_splev, (x2, coeff_2)))
The code above is optimized for a faster knot search when the knots are equispaced.
On my Core i7 machine, when interpolating at random (shuffled) points, the Numba version is faster (total time for 10,000 calls, each evaluating 10,000 points):
SciPy's splev = 0.896 s
Numba splev = 0.375 s
When interpolating at sorted (non-random) points, SciPy's version is faster:
SciPy's splev = 0.281 s
Numba splev = 0.375 s
References: https://github.com/scipy/scipy/tree/v1.7.1/scipy/interpolate/fitpack,
https://github.com/dbstein/fast_splines