I am trying to use JAX and Python to solve dF/dz_i = 0 for z_1, z_2, ..., z_n. However, my code does not seem to work: all I get back is zeros, which is the initial guess I passed in. I am writing this code as an exercise to get more familiar with JAX, and the goal is to find all the points on a sphere. The code is below. Am I using the minimize function and JAX correctly? Any help and advice will be greatly appreciated! (G and G2 are n×3 arrays of point coordinates, one row per point, and adjacent_points[i] lists the neighbors of point i.)
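To make the structure of the problem explicit, here is a short derivation. The notation is introduced only for illustration: x_i and y_i are G[i, 0] and G[i, 1], N(i) is the neighbor list adjacent_points[i], and φ_ij stands for whatever is accumulated for the pair (i, j) in the code below (the reference distance from G2 enters φ_ij only as a constant). Because F depends on z only through the squared distances d_ij², the chain rule gives:

```latex
\begin{aligned}
d_{ij}^2(z) &= (x_j - x_i)^2 + (y_j - y_i)^2 + (z_j - z_i)^2, \\
F(z) &= \sum_{i=1}^{n} \sum_{j \in \mathcal{N}(i)} \varphi_{ij}\bigl(d_{ij}^2(z)\bigr), \\
\frac{\partial F}{\partial z_k} &= \sum_{i=1}^{n} \sum_{j \in \mathcal{N}(i)} \varphi_{ij}'\bigl(d_{ij}^2(z)\bigr)\, 2\,(z_j - z_i)\,(\delta_{jk} - \delta_{ik}).
\end{aligned}
```

Every term carries the factor z_j - z_i, so any constant vector z_1 = z_2 = ... = z_n satisfies dF/dz_k = 0 for every k, whatever φ_ij is.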
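import jax.numpy as jnp
from jax import grad, jit
from jax.scipy.optimize import minimize  # assumed; the post does not say which minimize was used

def distance_3d(p1, p2):
    # Euclidean distance between two 3-D points given as (x, y, z) tuples
    x1, y1, z1 = p1
    x2, y2, z2 = p2
    return jnp.sqrt((x2 - x1)**2 + (y2 - y1)**2 + (z2 - z1)**2)

def spher_potential_computation_numerical(G, G2, adjacent_points):
    N = len(G)

    def F(z):
        # Potential: for every neighbor pair, compare the current squared edge length
        # (x, y from G, unknown z) with the reference one (x, y from G2, z from G)
        total = 0.0
        for i in range(N):
            p1 = (G[i, 0], G[i, 1], z[i])
            p3 = (G2[i, 0], G2[i, 1], G[i, 2])
            for j in adjacent_points[i]:
                p2 = (G[j, 0], G[j, 1], z[j])
                p4 = (G2[j, 0], G2[j, 1], G[j, 2])
                total += (distance_3d(p1, p2)**2 - distance_3d(p3, p4)**2)**2
        return total

    dF_dz = jit(grad(F))  # Calculate the gradient using JAX's automatic differentiation

    # Initial guess for the z-values
    z_guess = jnp.zeros(N)

    def equations(z_vals):
        # Scalar residual: sum of |dF/dz_i|, which is zero exactly where the gradient vanishes
        return jnp.sum(jnp.abs(dF_dz(z_vals)))

    # Solve the nonlinear system dF/dz = 0 numerically by minimizing the residual
    res = minimize(equations, z_guess, method='BFGS')
    z_solutions = res.x

    # Update the z-values in the mesh
    G = G.at[:, 2].set(z_solutions)
    return G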
The result is the same as the initial guess, even if I change the initial guess to all ones.
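One likely explanation, sketched below under stated assumptions: F depends on z only through the differences z_j - z_i, so any constant guess (all zeros or all ones) is an exact solution of dF/dz = 0. The residual sum |dF/dz| is then already at its global minimum of 0, and BFGS stops at once and returns the initial guess. A common alternative is to minimize F itself, so that BFGS works with the exact gradient JAX provides, starting from a non-constant guess. Note that this only targets the minima among the solutions of dF/dz = 0. The snippet assumes jax.scipy.optimize.minimize, takes the summand as (d² - d_ref²)² (an assumption about the intended potential), and uses a hypothetical helper make_potential with made-up three-point data. It illustrates the behavior and is not a drop-in fix for the sphere problem.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # float64 gives BFGS more precision to reach its tolerance


def make_potential(G, G2, adjacent_points):
    """Hypothetical helper: F(z) = sum over neighbor pairs of (d^2 - d_ref^2)^2.

    Same point construction as the question: current points take x, y from G and
    z from the unknowns; reference points take x, y from G2 and z from G.
    """
    N = G.shape[0]

    def F(z):
        total = 0.0
        for i in range(N):
            for j in adjacent_points[i]:
                d2 = (G[j, 0] - G[i, 0])**2 + (G[j, 1] - G[i, 1])**2 + (z[j] - z[i])**2
                d2_ref = (G2[j, 0] - G2[i, 0])**2 + (G2[j, 1] - G2[i, 1])**2 + (G[j, 2] - G[i, 2])**2
                total += (d2 - d2_ref)**2
        return total

    return F


# Made-up toy data: three points in a chain along x. Each reference edge has squared
# length 2 instead of 1, so an edge matches its reference only if neighboring z differ by 1.
G = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
G2 = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]])
adjacent_points = [[1], [0, 2], [1]]

F = make_potential(G, G2, adjacent_points)
grad_F = jax.grad(F)

# 1) A constant z is a stationary point: every z[j] - z[i] is zero, so the gradient
#    vanishes and BFGS stops at the initial guess (the symptom in the question).
z_const = jnp.zeros(3)
print(grad_F(z_const))  # zero gradient
res_const = minimize(F, z_const, method="BFGS")
print(res_const.x)  # unchanged initial guess

# 2) Minimize F itself (not sum |dF/dz|) from a non-constant guess.
z0 = jnp.array([0.0, 0.8, 1.5])
res = minimize(F, z0, method="BFGS")
print(res.x, res.fun)  # neighboring z-values end up about 1 apart, F near 0
print(grad_F(res.x))  # close to zero, so dF/dz = 0 holds at this point as well
```

Because F is unchanged when the same constant is added to every z_i, the solution is only determined up to such a shift; fixing one z-value (or subtracting the mean) removes that freedom.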