
I am trying to practice using JAX for optimization problems, and I am starting with a simple one: minimizing the Lennard-Jones potential for just 2 points. I set both epsilon and sigma in the Lennard-Jones potential equal to 1, so the potential is just F = 4(1/r^12 - 1/r^6), where r is the distance between the two points. The result should be r = 2^(1/6), which is approximately 1.12.
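
As a quick sanity check of that value (a small standalone snippet, not part of my original script below), the gradient of the potential vanishes at r = 2^(1/6):

import jax

# 4 * (1/r**12 - 1/r**6) with epsilon = sigma = 1
lj = lambda r: 4 * (r**-12 - r**-6)
r_star = 2 ** (1 / 6)
print(r_star)                # 1.1224...
print(jax.grad(lj)(r_star))  # ~0, so this is the stationary point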

Using JAX, I wrote the following code, which is pretty simple and short. My initial guess for the two points is [0, 1], which I think is reasonable (for the Lennard-Jones potential, a guess with r too small could be a problem because the potential approaches infinity). As I mentioned, I am expecting a value of r around 1.12 after the minimization; however, the result I get is [-0.71276042 1.71276042], so the distance is about 2.4, which is clearly too big, and I am wondering how I can fix it. I originally suspected it might be a precision issue, so I changed the data type to float64, but the results are still the same. Any help will be greatly appreciated! Here is my code:

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax import vmap
import matplotlib.pyplot as plt

N = 2
jax.config.update("jax_enable_x64", True)
x_init = jnp.arange(N, dtype=jnp.float64)
epsilon = 1
sigma = 1

def potential(r):
    r = jnp.where(r == 0, jnp.finfo(jnp.float64).eps, r)
    return 4 * epsilon * ((sigma/r)**12 - (sigma/r)**6)

def F(x):
    # Compute all pairwise distances
    r = jnp.abs(x[:, None] - x[None, :])
    # Compute all pairwise potentials
    pot = vmap(vmap(potential))(r)
    # Exclude the diagonal (distance = 0) and avoid double-counting by taking upper triangular part
    pot = jnp.triu(pot, 1)
    # Sum up all the potentials
    total = jnp.sum(pot)
    return total

# Potential energy at the initial guess
print(F(x_init))

# Minimize the function
result = minimize(F, x_init, method='BFGS')

# Extract the optimized positions of the points
x_solutions = result.x
print(x_solutions)
Heng Yuan
  • Yeah, I understand your point, but in my code I use x_init = jnp.arange(N, dtype=jnp.float64), which means the x-coordinates of the two points are [0, 1] (no y values, since I'm assuming the 1D case), so r, the distance between the two points, is 1, not zero. – Heng Yuan May 28 '23 at 20:21
  • Oh I see what you mean - sorry! – slothrop May 28 '23 at 20:23
  • One thought: if you allow both x coordinates to vary, then there are infinitely many "optimal" solutions - anything of the form (p, p+1.12) is a solution. So this might confuse the optimiser. You could try fixing the x-coordinate of one point at 0 without loss of generality, then just optimise for the x-coord of the other point. – slothrop May 28 '23 at 20:25
  • True. I also considered that multiple degrees of freedom might "confuse" JAX. I'm kind of putting my faith in it, haha. I can try fixing the original point and see what happens. Thank you for helping me. – Heng Yuan May 28 '23 at 20:35
  • Yeah, I fixed the first point and let it just be 0; however, now the distance becomes 2, which is less than 2.4 but still not 1.12. – Heng Yuan May 28 '23 at 20:44
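
For reference, here is a minimal sketch of the variant suggested in the comments above: pin the first point at 0 and optimize only the second coordinate. The wrapper name F_fixed is introduced just for this sketch, and it reuses F, jnp, and minimize from the code in the question; as the answer below explains, BFGS can still struggle with this potential.

def F_fixed(x_free):
    # x_free holds only the coordinate of the second point; the first stays at 0
    return F(jnp.concatenate([jnp.zeros(1), x_free]))

result_fixed = minimize(F_fixed, jnp.array([1.0]), method='BFGS')
print(result_fixed.x)  # ideally near 2**(1/6) ~ 1.12, but BFGS may still overshoot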

1 Answer


This function is one that would be very difficult for any unconstrained gradient-based optimizer to correctly optimize. Holding one point at zero and varying the other point on the range (0, 5], we see the potential looks like this:

r = jnp.linspace(0.1, 5.0, 1000)
plt.plot(r, jax.vmap(lambda ri: F(jnp.array([0, ri])))(r))
plt.ylim(-2, 10)

[Plot: the potential F for points [0, r], showing a steep well with its minimum near r ≈ 1.12, a divergence as r → 0, and a nearly flat tail toward zero at larger r.]

To the left of the minimum, the gradient quickly diverges to negative infinity, meaning for nearly any reasonable step size, the optimizer will likely overshoot the minimum. Then on the right side, if the optimizer goes even a few units too far, the gradient tends to zero, meaning for nearly any reasonable step size, the optimizer will get stuck in a regime where the potential has almost no variation.
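
To make this concrete, here is the gradient at a few separations (an illustrative snippet of my own, reusing F from the question; the quoted values are approximate):

grad_F = jax.grad(lambda ri: F(jnp.array([0.0, ri])))
print(grad_F(0.9))           # large and negative (order -1e2): a step easily overshoots the well
print(grad_F(2 ** (1 / 6)))  # ~0 at the true minimum
print(grad_F(3.0))           # ~0.01: nearly flat
print(grad_F(6.0))           # ~1e-4: essentially no signal left for the optimizer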

Add to this the fact that you've set up the model with two degrees of freedom in a degenerate potential, and it's not surprising that gradient-based optimization methods are failing.

You can make some progress here by minimizing the log of the shifted potential, which has the effect of smoothing the steep gradients, and lets the BFGS minimizer find an expected minimum:

result = minimize(lambda x: jnp.log(2 + F(x)), x_init, method='BFGS')
print(result.x)
# [-0.06123102  1.06123102]
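# i.e. a separation of 1.06123102 - (-0.06123102) ≈ 1.1225, which is 2**(1/6) as expected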

But in general my suggestion would probably be to opt for a constrained optimization approach instead, perhaps one of the JAXOpt constrained optimization methods, where you can rule out problematic regions of the parameter space.
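
For example, here is a rough sketch along those lines using jaxopt's ScipyBoundedMinimize (this assumes the jaxopt package is installed; the bounds are just illustrative choices that rule out the problematic regions):

import jaxopt

# Box-constrained L-BFGS-B: pin the first coordinate at 0 and keep the second
# in [0.8, 2.0], ruling out both the r -> 0 divergence and the flat large-r tail.
lower = jnp.array([0.0, 0.8])
upper = jnp.array([0.0, 2.0])
solver = jaxopt.ScipyBoundedMinimize(fun=F, method="l-bfgs-b")
res = solver.run(x_init, bounds=(lower, upper))
print(res.params)  # should land near [0., 1.122], i.e. a separation of about 2**(1/6)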

jakevdp
  • I see. Thank you so much for your help! It really fixed the issue and I appreciate your detailed explanation. Have a good weekend! – Heng Yuan May 28 '23 at 22:05
  • Btw, I want to ask another thing: if I increase the number of points in this system, say to 10 points, could that cause an issue for minimizing the log of the shifted potential? I did a test with N = 10 and the results look kind of weird. – Heng Yuan May 28 '23 at 22:24
  • Yeah, the log thing is a band-aid. I wouldn't expect it to work in a 10-dimensional potential – fundamentally, the problem is that you have a highly-degenerate potential with diverging gradients that does not lend itself to successful analysis via unconstrained gradient-based optimization. I don't know what solution I'd recommend in general, but I'd suggest reading up in the computational particle physics literature: I suspect this is a class of problem that researchers have written about. – jakevdp May 29 '23 at 14:27