This implementation of the Struve function H0 is correct when checked against the scipy.special implementation, but since the Struve function is not implemented in jax.scipy, I translated it from a MATLAB implementation. I am struggling to implement it in JAX so that it also supports reverse-mode differentiation. Does that mean it has to be written in pure JAX (jax.numpy / jax.lax operations) rather than plain Python? Thanks
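On the "pure JAX" question: jax.grad, jax.jit and jax.vmap work by tracing the function with abstract values, so the operations on x should be JAX operations (jnp.sin, jnp.sqrt rather than math.sin, math.sqrt), and Python control flow that depends on the value of x (`if x <= 20`, `break`, `range(1, km)`) cannot run under jit or vmap. Below is a minimal sketch of the usual replacement for such an `if`, on a hypothetical two-branch toy function (`piecewise_python` / `piecewise_jax`, not the Struve function): compute both branches and select with jnp.where, using a second jnp.where to give each branch a safe input so the discarded branch cannot turn the gradient into NaN.

```python
import jax
import jax.numpy as jnp


def piecewise_python(x):
    # Hypothetical toy function with Python control flow on the value of x.
    if x <= 20.0:
        return x * x
    return 8000.0 / x


def piecewise_jax(x):
    # Same hypothetical function: both branches run, jnp.where selects one.
    small = x <= 20.0
    # "Double where": each branch sees an input on which it is finite, so the
    # discarded branch (e.g. 8000 / 0 at x = 0) cannot put NaN in the gradient.
    x_small = jnp.where(small, x, 0.0)
    x_large = jnp.where(small, 21.0, x)
    return jnp.where(small, x_small * x_small, 8000.0 / x_large)


xs = jnp.array([0.0, 10.0, 30.0])

python_failed = False
try:
    jax.vmap(piecewise_python)(xs)
except jax.errors.ConcretizationTypeError:
    # bool() on a batched tracer is not allowed, so the `if` fails under vmap
    python_failed = True

values = jax.vmap(piecewise_jax)(xs)           # [0., 100., 266.66...]
grads = jax.vmap(jax.grad(piecewise_jax))(xs)  # [0., 20., -8.88...]
```

Without the inner jnp.where calls, the x = 0 entry would compute 8000 / 0 in the unused branch, and its infinite derivative multiplied by the zero cotangent gives NaN in `grads`.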
import math


def struveh0(x):
    """Struve function H0(x) for real x >= 0 (translated from MATLAB)."""
    pi = math.pi
    if x <= 20.0:
        # Power series: H0(x) = (2x/pi) * sum_k (-1)^k x^(2k) / ((2k+1)!!)^2
        r = 1.0
        s = 1.0
        a0 = 2.0 * x / pi
        for k in range(1, 60):
            term1 = x / (2.0 * k + 1.0)
            r = -r * term1 * x / (2.0 * k + 1.0)
            s = s + r
            if abs(r) < abs(s) * 1.0e-12:  # remaining terms are negligible
                break
        sh0 = a0 * s
    else:
        # Asymptotic expansion: H0(x) ~ Y0(x) + 2/(pi x) * (1 - 1/x^2 + 9/x^4 - ...)
        r = 1.0
        s = 1.0
        # Sum at most 25 terms (and at most floor((x + 1) / 2)); the series diverges
        km = int(min(25, max(math.floor(0.5 * (x + 1.0)), 0)))
        for k in range(1, km + 1):
            r = -r * ((2.0 * k - 1.0) / x) ** 2
            s = s + r
            if abs(r) < abs(s) * 1.0e-12:
                break
        # Approximation of the Bessel function Y0(x) for large x
        t = 4.0 / x
        t2 = t * t
        p0 = (
            (((-0.37043e-5 * t2 + 0.173565e-4) * t2 - 0.487613e-4) * t2 + 0.17343e-3)
            * t2
            - 0.1753062e-2
        ) * t2 + 0.3989422793e0
        q0 = t * (
            ((((0.32312e-5 * t2 - 0.142078e-4) * t2 + 0.342468e-4) * t2 - 0.869791e-4) * t2
             + 0.4564324e-3) * t2
            - 0.0124669441e0
        )
        ta0 = x - 0.25 * pi
        by0 = 2.0 / math.sqrt(x) * (p0 * math.sin(ta0) + q0 * math.cos(ta0))
        sh0 = 2.0 / (pi * x) * s + by0
    return sh0


print(struveh0(10))
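A minimal pure-JAX sketch of the same algorithm, under two assumptions: float64 is acceptable (the block enables `jax_enable_x64` globally, because the alternating power series loses several digits to cancellation near x = 20 and JAX defaults to float32), and x is a real scalar >= 0 (arrays go through jax.vmap). `struveh0_jax` is a hypothetical name. Each Python-only construct is swapped out: `if x <= 20` becomes "compute both branches and select with jnp.where" (with safe inputs for the discarded branch), `break` becomes a `done` flag that freezes the running sum, and the data-dependent `range(1, km)` becomes a fixed-length jax.lax.fori_loop that masks iterations with k > km. With static loop bounds, fori_loop lowers to a scan, which supports reverse-mode differentiation.

```python
import jax
import jax.numpy as jnp

# float64 is needed for the power series near x = 20; JAX uses float32 by default.
jax.config.update("jax_enable_x64", True)


def struveh0_jax(x):
    """Hypothetical pure-JAX port: Struve H0(x) for real scalar x >= 0."""
    x = jnp.asarray(x, dtype=jnp.float64)
    pi = jnp.pi
    small = x <= 20.0
    # Double where: each branch gets an input on which it stays finite, so the
    # branch discarded by the final jnp.where cannot produce NaN gradients.
    xs = jnp.where(small, x, 1.0)   # input to the power series
    xl = jnp.where(small, 21.0, x)  # input to the asymptotic expansion
    one = jnp.ones_like(x)
    not_done = jnp.zeros_like(small)  # boolean False

    # x <= 20: power series with 59 fixed iterations; `done` freezes r and s
    # once the original break test |r| < 1e-12 |s| has fired.
    def series_body(k, carry):
        r, s, done = carry
        r_new = -r * (xs / (2.0 * k + 1.0)) * xs / (2.0 * k + 1.0)
        s_new = s + r_new
        r = jnp.where(done, r, r_new)
        s = jnp.where(done, s, s_new)
        done = done | (jnp.abs(r_new) < jnp.abs(s_new) * 1.0e-12)
        return r, s, done

    _, s_small, _ = jax.lax.fori_loop(1, 60, series_body, (one, one, not_done))
    h_small = 2.0 * xs / pi * s_small

    # x > 20: asymptotic series with km = min(25, floor((x + 1) / 2)) terms.
    # km depends on x, so loop a static 25 times and mask iterations k > km.
    km = jnp.clip(jnp.floor(0.5 * (xl + 1.0)), 0.0, 25.0)

    def asym_body(k, carry):
        r, s, done = carry
        active = (k <= km) & ~done
        r_new = -r * ((2.0 * k - 1.0) / xl) ** 2
        s_new = s + r_new
        r = jnp.where(active, r_new, r)
        s = jnp.where(active, s_new, s)
        done = done | (active & (jnp.abs(r_new) < jnp.abs(s_new) * 1.0e-12))
        return r, s, done

    _, s_large, _ = jax.lax.fori_loop(1, 26, asym_body, (one, one, not_done))

    # Approximation of the Bessel function Y0(x) for large x (same coefficients)
    t = 4.0 / xl
    t2 = t * t
    p0 = ((((-0.37043e-5 * t2 + 0.173565e-4) * t2 - 0.487613e-4) * t2
           + 0.17343e-3) * t2 - 0.1753062e-2) * t2 + 0.3989422793
    q0 = t * (((((0.32312e-5 * t2 - 0.142078e-4) * t2 + 0.342468e-4) * t2
                - 0.869791e-4) * t2 + 0.4564324e-3) * t2 - 0.0124669441)
    ta0 = xl - 0.25 * pi
    by0 = 2.0 / jnp.sqrt(xl) * (p0 * jnp.sin(ta0) + q0 * jnp.cos(ta0))
    h_large = 2.0 / (pi * xl) * s_large + by0

    return jnp.where(small, h_small, h_large)


# Values and reverse-mode gradients over a batch of x values via vmap.
x_grid = jnp.array([0.5, 10.0, 30.7, 50.0])
h_vals = jax.vmap(struveh0_jax)(x_grid)
dh_vals = jax.vmap(jax.grad(struveh0_jax))(x_grid)
```

Both loops (59 + 25 iterations) always run for every x; that is the usual cost of making a branch traceable. jax.lax.cond would not avoid it under vmap, where a batched predicate turns the cond into a select over both branches. It is worth comparing against scipy.special.struve(0, x) on your own grid of x values before relying on this port.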
Edit: I originally posted my JAX version, but it gives me an index error and a concretization error when I use it within another function that gets vmapped across two lists.
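Assuming the original JAX version kept the structure of the Python code above, the concretization error is consistent with `int(km.item())`, `range(1, km)` and the `if ... break` tests: each needs a concrete Python number, but inside jax.vmap (or jax.jit) x is an abstract tracer. Below is a minimal sketch of the usual workaround on a hypothetical stand-in series (`truncated_series`, `outer`, `MAX_TERMS` and the input lists are illustrative, not from the post): loop up to a static maximum with jax.lax.fori_loop, mask iterations beyond the data-dependent count, and vmap over two inputs that have been converted from lists to arrays.

```python
import jax
import jax.numpy as jnp

MAX_TERMS = 25  # static loop bound (hypothetical cap, mirroring min(25, ...))


def truncated_series(x):
    # Hypothetical stand-in: sum_{k=1}^{n} (-1)^k / x^(2k) with a data-dependent
    # n = min(25, floor((x + 1) / 2)); assumes x > 0.
    n = jnp.clip(jnp.floor(0.5 * (x + 1.0)), 0.0, MAX_TERMS)  # stays traced, no int()

    def body(k, carry):
        r, s = carry
        active = k <= n  # iterations past n leave r and s unchanged
        r = jnp.where(active, -r / (x * x), r)
        s = jnp.where(active, s + r, s)
        return r, s

    init = (jnp.ones_like(x), jnp.zeros_like(x))
    _, s = jax.lax.fori_loop(1, MAX_TERMS + 1, body, init)
    return s


def outer(a, b):
    # Hypothetical outer function that is vmapped over two inputs.
    return truncated_series(a) * b


a_list = [3.0, 8.0, 30.0]
b_list = [1.0, 2.0, 0.5]
# vmap maps over the leading axis of arrays; a Python list is a pytree of
# separate scalars, so convert the lists to arrays first.
a = jnp.asarray(a_list)
b = jnp.asarray(b_list)
vals = jax.vmap(outer)(a, b)
grads = jax.vmap(jax.grad(outer, argnums=0))(a, b)
```

Because the trip count is a Python constant, the loop also stays reverse-mode differentiable, and the masked iterations contribute zero to both the value and the gradient.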