I am practicing using JAX for optimization problems, starting with a simple one: minimizing the Lennard-Jones potential for just two points. I set both epsilon and sigma in the Lennard-Jones potential to 1, so the potential is just F = 4(1/r^12 - 1/r^6), where r is the distance between the two points. The minimum should be at r = 2^(1/6), which is approximately 1.12.
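For reference, the expected value comes from setting the derivative to zero: dF/dr = 4(-12/r^13 + 6/r^7) = 0 gives r^6 = 2. Below is a minimal JAX sketch that checks this. The helper names `lj`, `dlj`, and `d2lj` are my own. It confirms that the first derivative vanishes at r = 2^(1/6), that F = -1 there, and that the second derivative is positive, so the point is a minimum:

```python
import jax

# Enable float64 before any arrays are created
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def lj(r):
    """Lennard-Jones potential with epsilon = sigma = 1 (hypothetical helper)."""
    inv6 = (1.0 / r) ** 6          # (1/r)^6
    return 4.0 * (inv6**2 - inv6)  # 4 * ((1/r)^12 - (1/r)^6)


r_star = 2.0 ** (1.0 / 6.0)  # analytic minimizer, about 1.1225
dlj = jax.grad(lj)           # dF/dr
d2lj = jax.grad(dlj)         # d^2F/dr^2

print(r_star, lj(r_star), dlj(r_star), d2lj(r_star))
```

Analytically, the curvature at the minimum is 72/2^(1/3), about 57.1.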
Using JAX, I wrote the short code below. My initial guesses for the two points are [0, 1], which I think is reasonable: with the Lennard-Jones potential the starting distance must not be too small, because the potential goes to infinity as r approaches 0. As I mentioned, I expect r to be around 1.12 after the minimization. However, the result I get is [-0.71276042 1.71276042], so the distance is about 2.43, which is clearly too large, and I am wondering how to fix it. I originally suspected a precision problem, so I changed the data type to float64, but the results are still the same. Any help will be greatly appreciated! Here is my code:
```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax import vmap
import matplotlib.pyplot as plt

N = 2
# Enable float64; this must happen before any arrays are created
jax.config.update("jax_enable_x64", True)
x_init = jnp.arange(N, dtype=jnp.float64)  # initial positions [0., 1.]
epsilon = 1
sigma = 1


def potential(r):
    # Replace zero distances (the diagonal) with machine epsilon to avoid division by zero
    r = jnp.where(r == 0, jnp.finfo(jnp.float64).eps, r)
    return 4 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)


def F(x):
    # Compute all pairwise distances
    r = jnp.abs(x[:, None] - x[None, :])
    # Compute all pairwise potentials
    pot = vmap(vmap(potential))(r)
    # Exclude the diagonal (distance = 0) and avoid double-counting by taking the upper triangular part
    pot = jnp.triu(pot, 1)
    # Sum up all the potentials
    total = jnp.sum(pot)
    return total


# Minimize the function
print(F(x_init))  # energy at the initial guess
result = minimize(F, x_init, method='BFGS')

# Extract the optimized positions of the points
x_solutions = result.x
print(x_solutions)
```
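One way to narrow this down is sketched below; it is not a verified fix. `jax.scipy.optimize.minimize` returns an `OptimizeResults` whose `success`, `status`, `nit`, and `jac` fields show whether BFGS met its gradient tolerance or stopped early, for example after a failed line search. The sketch also looks at the gradient at the initial guess: at r = 1 it is 24 per coordinate. I assume the textbook choice of an identity initial inverse Hessian and a unit first trial step; I have not checked which initial step length JAX's line search actually uses. Under that assumption, the first trial point moves the two points 49 apart. The line search then has to pick a step on a curve that is very steep near r = 1 and almost flat further out. The rewritten energy function sums only the i < j pairs via `jnp.triu_indices`, which avoids the r = 0 diagonal and the `eps` workaround entirely. `pair_energy`, `i_idx`, and `j_idx` are hypothetical names of mine:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.optimize import minimize

N = 2
x_init = jnp.arange(N, dtype=jnp.float64)  # [0., 1.]

# Index pairs (i, j) with i < j, so zero self-distances never appear
i_idx, j_idx = jnp.triu_indices(N, k=1)


def pair_energy(x):
    """Total Lennard-Jones energy (epsilon = sigma = 1) over distinct pairs."""
    r = jnp.abs(x[i_idx] - x[j_idx])
    inv6 = (1.0 / r) ** 6
    return jnp.sum(4.0 * (inv6**2 - inv6))


# Energy and gradient at the initial guess
print("F(x_init) =", pair_energy(x_init))
g0 = jax.grad(pair_energy)(x_init)
print("grad F(x_init) =", g0)

# Assumption: identity initial inverse Hessian and unit step length,
# so the first trial point is x_init - g0
x_trial = x_init - g0
print("trial distance =", jnp.abs(x_trial[1] - x_trial[0]))

# Inspect whether BFGS actually converged
result = minimize(pair_energy, x_init, method="BFGS")
print("success:", result.success, "status:", result.status, "nit:", result.nit)
print("|grad| at result:", jnp.linalg.norm(result.jac))
print("x:", result.x, "distance:", jnp.abs(result.x[1] - result.x[0]))
```

If `success` comes back false, the reported positions are wherever BFGS stopped, not a minimum. A starting distance or problem scaling that avoids the huge initial gradient would then be worth trying.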