
This implementation of the Struve function H0 agrees with the `scipy.special` implementation. Since the function is not available in `jax.scipy`, I translated it from a MATLAB implementation. I am struggling to implement it in JAX so that it also supports reverse-mode differentiation. Does that mean it has to be written in pure JAX rather than pure Python? Thanks

import math


def struveh0(x):
    """Struve function H0(x) for a scalar x >= 0 (pure Python, not JAX-traceable)."""
    pi = math.pi
    r = 1.0e0
    sh0 = 0
   
    if x <= 20.0e0:
        s = 1.0e0
        a0 = 2.0 * x/ pi
        for k in range(1,60):
            term1 = x/ (2.0 * k + 1.0)
            r =-r * term1 * x / (2.0 * k + 1.0)
            s = s + r
            if abs(r) < abs(s) * 1.0e-12:
                break
        sh0 = a0 * s

    if x > 20.0e0:
        r = 1.0e0
        s = 1.0e0
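        # Number of asymptotic-series terms: floor((x + 1) / 2), capped at 25.
        # math.floor returns a plain Python int, so no .item() call is needed.
        km = min(25, math.floor(0.5 * (x + 1.0))) + 1  # +1 so range(1, km) includes the last term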

        for k in range(1, km):
            r = -r * ((2.0 * k - 1.0) / x) ** 2
            s = s + r
            if abs(r)<(abs(s) * 1.0e-12):
                break
        t = 4.0 / x
        t2 = t * t
        p0 = (
            (((-0.37043e-5 * t2 + 0.173565e-4) * t2 - 0.487613e-4) * t2 + 0.17343e-3)
            * t2
            - 0.1753062e-2
        ) * t2 + 0.3989422793e0
       
        q0 = t * (((((0.32312e-5 * t2 - 0.142078e-4) * t2 + 0.342468e-4) * t2 - 0.869791e-4) * t2 + 0.4564324e-3) * t2 - 0.0124669441e0)

        ta0 = x - 0.25e0 * pi
        by0 = 2.0e0 / math.sqrt(x) * (p0 * math.sin(ta0) + q0 * math.cos(ta0))

        sh0 = 2.0e0 / (pi * x) * s + by0
    
    return sh0


struveh0(10)

Edit: I originally posted the JAX version, but it gives me an index error and a concretization error when I use it inside another function that is vmapped across two lists.

Kapil
  • I don't entirely understand your code. It looks like you're trying to use the shape and size of boolean arrays as a proxy for their value, which will not work. You might find [JAX Sharp Bits: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow) to be a useful resource. – jakevdp May 03 '23 at 01:46
  • Yes, I can provide a pure Python implementation of it. I have edited the question. – Kapil May 03 '23 at 18:33
  • I see, thanks. In order to use this with JAX transformations, you'll have to replace your `if` statements with `lax.cond` statements, and replace your loops with `lax.fori_loop` or `lax.while_loop` constructs. There is more information at the link I shared above. – jakevdp May 03 '23 at 18:59
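
One possible way to follow the advice in the comments is sketched below. It is not an answer from the thread, only an illustration under several assumptions: 64-bit mode is enabled (the alternating power series cancels badly in float32 near x = 20), x >= 0 as in the original, and the early `break` is dropped in favor of fixed trip counts (59 and 25) so that `lax.fori_loop` stays reverse-mode differentiable. It also swaps `lax.cond` for `jnp.where` on clamped inputs, since under `vmap` both branches are evaluated anyway and the unselected one must stay finite for the gradient. The helper names `_h0_series` and `_h0_asymptotic` are made up for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: float64 is wanted; this must run before any arrays are created.
jax.config.update("jax_enable_x64", True)


def _h0_series(x):
    """Power series branch, intended for 0 <= x <= 20 (hypothetical helper)."""
    def body(k, carry):
        r, s = carry
        r = -r * (x / (2.0 * k + 1.0)) ** 2
        return r, s + r

    one = jnp.ones_like(x)
    # Python-int bounds give a static trip count, so reverse-mode AD works.
    _, s = lax.fori_loop(1, 60, body, (one, one))
    return 2.0 * x / jnp.pi * s


def _h0_asymptotic(x):
    """Asymptotic branch, intended for x > 20 (hypothetical helper)."""
    km = jnp.minimum(25.0, jnp.floor(0.5 * (x + 1.0)))

    def body(k, carry):
        r, s = carry
        # Terms beyond km are masked to zero instead of stopping the loop.
        r = jnp.where(k <= km, -r * ((2.0 * k - 1.0) / x) ** 2, 0.0)
        return r, s + r

    one = jnp.ones_like(x)
    _, s = lax.fori_loop(1, 26, body, (one, one))

    t = 4.0 / x
    t2 = t * t
    p0 = ((((-0.37043e-5 * t2 + 0.173565e-4) * t2 - 0.487613e-4) * t2
           + 0.17343e-3) * t2 - 0.1753062e-2) * t2 + 0.3989422793
    q0 = t * (((((0.32312e-5 * t2 - 0.142078e-4) * t2 + 0.342468e-4) * t2
                - 0.869791e-4) * t2 + 0.4564324e-3) * t2 - 0.0124669441)
    ta0 = x - 0.25 * jnp.pi
    by0 = 2.0 / jnp.sqrt(x) * (p0 * jnp.sin(ta0) + q0 * jnp.cos(ta0))
    return 2.0 / (jnp.pi * x) * s + by0


def struveh0(x):
    x = jnp.asarray(x, dtype=jnp.float64)
    small = x <= 20.0
    # Clamp each branch's input to its valid range so the branch that is
    # not selected still produces finite values and finite gradients.
    x_small = jnp.where(small, x, 20.0)
    x_large = jnp.where(small, 20.0, x)
    return jnp.where(small, _h0_series(x_small), _h0_asymptotic(x_large))


xs = jnp.array([0.5, 10.0, 25.0, 100.0])
values = jax.vmap(struveh0)(xs)
grads = jax.vmap(jax.grad(struveh0))(xs)
```

Because every operation here is elementwise, `struveh0(xs)` would also work without `vmap`; the `vmap` calls are only there to mirror the use case in the question. Results should still be checked against `scipy.special.struve(0, x)` before relying on this sketch.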

0 Answers