I am trying to use JAX and Python to solve dF/dz = 0 for z1, z2, ..., zn. However, my code does not seem to work: all I get back is zero, which is the initial guess I put in. I am writing this code as an exercise to get more familiar with JAX. I am trying to find all the points on a sphere. The code is attached below. Am I using the minimize function and JAX correctly? Any help and advice will be greatly appreciated! (G and G2 are two 3*n matrices containing coordinates, and adjacent_points is an array that lists each point's neighbors.)

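import jax.numpy as jnp
from jax import grad, jit
from scipy.optimize import minimize


def distance_3d(p1, p2):
    # Euclidean distance between two 3D points given as (x, y, z) tuples
    x1, y1, z1 = p1
    x2, y2, z2 = p2
    return jnp.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)


def spher_potential_computation_numerical(G, G2, adjacent_points):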
    N = len(G)

    def F(z):
        total = 0.0
        for i in range(N):
            p1 = (G[i, 0], G[i, 1], z[i])
            p3 = (G2[i, 0], G2[i, 1], G[i, 2])
            for j in adjacent_points[i]:
                p2 = (G[j, 0], G[j, 1], z[j])
                p4 = (G2[j, 0], G2[j, 1], G[j, 2])
                total += (distance_3d(p1, p2))**2 - (distance_3d(p3, p4)**2)**2
        return total

    dF_dz = jit(grad(F))  # Calculate the gradient using JAX's automatic differentiation
    # Initial guess for the z-values
    z_guess = jnp.zeros(N)

    def equations(z_vals):
        return jnp.sum(jnp.abs(dF_dz(z_vals)))
    # Solve the system of PDEs numerically
    res = minimize(equations, z_guess, method='BFGS')
    z_solutions = res.x

    # Update the z-values in the mesh
    G = G.at[:, 2].set(z_solutions)

    return G

The result is the same as the initial guess, even if I change the initial guess to all ones.
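One possible explanation, offered as a sketch rather than a diagnosis of the exact setup: in F as written, the only z-dependent part of each term is (z[i] - z[j])**2, and its derivative vanishes whenever all z values are equal. Any constant guess (all zeros, all ones) would then already satisfy dF/dz = 0, so the optimizer has nothing left to do. The check below uses hypothetical toy data (`xy` and a 4-point square with hand-written neighbors, not the original mesh) and keeps only the squared-distance term.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data (not the original mesh): 4 points on a unit square,
# each adjacent to its two neighbours along the square's edges.
xy = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
adjacent_points = [[1, 3], [0, 2], [1, 3], [0, 2]]


def F(z):
    """Sum of squared 3D distances between adjacent points (simplified stand-in)."""
    total = 0.0
    for i, neighbours in enumerate(adjacent_points):
        for j in neighbours:
            # x/y part is constant in z; only (z[i] - z[j])**2 depends on z
            total += jnp.sum((xy[i] - xy[j]) ** 2) + (z[i] - z[j]) ** 2
    return total


dF_dz = jax.grad(F)

grad_at_zeros = dF_dz(jnp.zeros(4))   # constant guess
grad_at_ones = dF_dz(jnp.ones(4))     # another constant guess
grad_at_varied = dF_dz(jnp.array([0.0, 0.5, 1.0, 0.2]))  # non-constant guess

print("gradient at zeros: ", grad_at_zeros)
print("gradient at ones:  ", grad_at_ones)
print("gradient at varied:", grad_at_varied)
```

If the gradient is zero at every constant guess, as it is in this toy case, a non-constant starting point (or an extra constraint that pins some z values) is probably needed before the solver can move.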

Heng Yuan
  • Hi, welcome to StackOverflow! You'll be more likely to get a good answer if you include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). For example, it's not clear what the `minimize` function is in your code, or what shape arrays you're passing to the `spher_potential_computation_numerical` function. Feel free to edit your question with the additional code. – jakevdp May 11 '23 at 23:56
  • Hey, thanks for the reply. `minimize` is the function imported with `from scipy.optimize import minimize`, and I am passing three fairly arbitrary matrices to the function. – Heng Yuan May 12 '23 at 01:16
  • Great, did you read the recommendations at the link? Code is much easier to understand than verbal descriptions of code. – jakevdp May 12 '23 at 03:20
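
For reference, a minimal reproducible example along the lines requested in these comments might look like the sketch below. Everything in it is an assumption made for illustration: the toy mesh (`xy`, `adjacent_points`), the simplified F, and the wrapper names `fun` and `jac` are hypothetical. It also swaps the sum of absolute gradient values for a sum of squares, since BFGS assumes a smooth objective, enables 64-bit floats because SciPy works in float64 while JAX defaults to float32, and passes a JAX-computed gradient through the `jac` argument instead of relying on finite differences.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# SciPy works in float64, while JAX defaults to float32.
jax.config.update("jax_enable_x64", True)

# Hypothetical toy mesh (not the asker's data): 4 points on a unit square.
xy = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
adjacent_points = [[1, 3], [0, 2], [1, 3], [0, 2]]


def F(z):
    """Sum of squared 3D distances between adjacent points (simplified)."""
    total = 0.0
    for i, neighbours in enumerate(adjacent_points):
        for j in neighbours:
            total += jnp.sum((xy[i] - xy[j]) ** 2) + (z[i] - z[j]) ** 2
    return total


dF_dz = jax.grad(F)


def objective(z):
    # Smooth stand-in for sum(|dF/dz|): sum of squared gradient components.
    return jnp.sum(dF_dz(z) ** 2)


objective_jit = jax.jit(objective)
objective_grad = jax.jit(jax.grad(objective))


def fun(z):
    # SciPy expects a plain Python float
    return float(objective_jit(jnp.asarray(z)))


def jac(z):
    # SciPy expects a NumPy float64 array
    return np.asarray(objective_grad(jnp.asarray(z)), dtype=np.float64)


z_guess = np.array([0.0, 0.5, 1.0, 0.2])  # deliberately non-constant
res = minimize(fun, z_guess, jac=jac, method="BFGS")

print("success:", res.success)
print("objective at guess:   ", fun(z_guess))
print("objective at solution:", res.fun)
print("solution:", res.x)
```

In this toy problem the optimizer moves away from the starting point and ends near a constant z, which is consistent with constant z being a stationary point of the simplified F.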

0 Answers