Tags: python, machine-learning, optimization, minimization, jax

JAX code for minimizing Lennard-Jones potential for 2 points in Python gives unexpected results


I am trying to practice using JAX for optimization problems, starting with a simple one: minimizing the Lennard-Jones potential for just 2 points. I set both epsilon and sigma in the Lennard-Jones potential equal to 1, so the potential is just F = 4(1/r^12 - 1/r^6), where r is the distance between the two points. The result should be r = 2^(1/6), which is approximately 1.12.
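
For reference, the expected separation follows from setting the derivative of F to zero (with epsilon = sigma = 1). The same short derivation also gives the minimum value of the potential and its slope at r = 1, which is the separation of the initial guess used below:

```latex
F(r) = 4\left(r^{-12} - r^{-6}\right), \qquad
F'(r) = 4\left(-12\,r^{-13} + 6\,r^{-7}\right) = 24\,r^{-13}\left(r^{6} - 2\right)

F'(r^{*}) = 0 \;\Rightarrow\; r^{*} = 2^{1/6} \approx 1.1225, \qquad
F(r^{*}) = 4\left(\tfrac{1}{4} - \tfrac{1}{2}\right) = -1, \qquad
F'(1) = -24
```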

Using JAX, I wrote the following code, which is pretty simple and short. My initial guesses for the two points are [0, 1], which I think is reasonable (the Lennard-Jones potential approaches infinity as r goes to 0, so an initial r that is too small could be a problem). As I mentioned, I expect r to be around 1.12 after the minimization; however, the result I get is [-0.71276042 1.71276042], so the distance is about 2.43, which is clearly too big. How can I fix this? I originally suspected a precision problem, so I changed the data type to float64, but the results are still the same. Any help will be greatly appreciated! Here is my code:

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax import vmap
import matplotlib.pyplot as plt

N = 2
jax.config.update("jax_enable_x64", True)
x_init = jnp.arange(N, dtype=jnp.float64)
epsilon = 1
sigma = 1

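def potential(r):
    # Replace zero distances (the diagonal of the distance matrix) with a tiny
    # value to avoid division by zero; those entries are masked out in F below.
    r = jnp.where(r == 0, jnp.finfo(jnp.float64).eps, r)
    return 4 * epsilon * ((sigma/r)**12 - (sigma/r)**6)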

def F(x):
    # Compute all pairwise distances
    r = jnp.abs(x[:, None] - x[None, :])
    # Compute all pairwise potentials
    pot = vmap(vmap(potential))(r)
    # Exclude the diagonal (distance = 0) and avoid double-counting by taking upper triangular part
    pot = jnp.triu(pot, 1)
    # Sum up all the potentials
    total = jnp.sum(pot)
    return total

# Minimize the function
print(F(x_init))  # energy at the initial guess
result = minimize(F, x_init, method='BFGS')

# Extract the optimized positions of the points
x_solutions = result.x
print(x_solutions)
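
A quick way to see why the optimizer jumps so far is to look at the gradient at the initial guess. The sketch below uses a simplified, self-contained re-definition of the two-point energy (the helper names `lj` and `energy` are illustrative, not from the original code) and `jax.grad` to show that the gradient at [0, 1] is [24, -24]. That is large compared with the roughly 0.12 distance to the minimum, and textbook BFGS starts from an identity inverse-Hessian estimate, so its first search direction is simply the negative of this gradient.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lj(r, epsilon=1.0, sigma=1.0):
    """Lennard-Jones pair potential 4*eps*((sigma/r)**12 - (sigma/r)**6)."""
    return 4 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)

def energy(x):
    """Energy of two points on a line; it depends only on their separation."""
    return lj(jnp.abs(x[1] - x[0]))

x_init = jnp.array([0.0, 1.0])
grad_init = jax.grad(energy)(x_init)

print(energy(x_init))  # energy at the initial guess (0, since r = 1)
print(grad_init)       # gradient at the initial guess: 24 and -24
```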

Solution

    This function would be very difficult for any unconstrained gradient-based optimizer to optimize correctly. Holding one point at zero and varying the other over the range [0.1, 5.0], the potential looks like this:

    # Energy as a function of the separation r, with the first point fixed at 0
    r = jnp.linspace(0.1, 5.0, 1000)
    plt.plot(r, jax.vmap(lambda ri: F(jnp.array([0, ri])))(r))
    plt.ylim(-2, 10)
    plt.show()
    

    [Plot: potential F versus separation r for r from 0.1 to 5.0, y-axis clipped to (-2, 10)]

    To the left of the minimum, the gradient quickly diverges to negative infinity, so for nearly any reasonable step size the optimizer is likely to overshoot the minimum. On the right side, if the optimizer goes even a few units too far, the gradient tends to zero (it decays roughly like 24/r^7), so for nearly any reasonable step size the optimizer gets stuck in a region where the potential has almost no variation.
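
To put numbers on this asymmetry, the following sketch evaluates F and dF/dr at a few separations with `jax.grad`. It assumes epsilon = sigma = 1 and a standalone helper `lj` (hypothetical, equivalent to the question's `potential`). The slope is several hundred in magnitude just left of the minimum but only about 0.01 at r = 3.

```python
import jax

jax.config.update("jax_enable_x64", True)

def lj(r):
    """Lennard-Jones pair potential with epsilon = sigma = 1."""
    return 4.0 * (r ** -12 - r ** -6)

dlj = jax.grad(lj)          # dF/dr as a function of the separation r
r_min = 2.0 ** (1.0 / 6.0)  # analytic minimum

# Steep wall on the left, nearly flat tail on the right.
for r in [0.8, 1.0, r_min, 2.0, 3.0]:
    print(f"r = {r:.4f}  F = {float(lj(r)):.6f}  dF/dr = {float(dlj(r)):.6f}")
```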

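    Add to this the fact that you've set up the model with two degrees of freedom in a degenerate potential: the energy depends only on the separation x[1] - x[0], so moving both points together changes nothing and the problem has a completely flat direction. Given all this, it's not surprising that gradient-based optimization methods are failing.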
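
The degeneracy can be checked directly. This sketch (assuming epsilon = sigma = 1 and a hypothetical two-point `energy` helper) shows that shifting both points by the same amount leaves the energy unchanged, and that the `jax.hessian` at the minimum has one zero eigenvalue, along the direction (1, 1), next to one positive eigenvalue of about 114.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def energy(x):
    """Two-point LJ energy (epsilon = sigma = 1); depends only on x[1] - x[0]."""
    r = jnp.abs(x[1] - x[0])
    return 4.0 * (r ** -12 - r ** -6)

r_min = 2.0 ** (1.0 / 6.0)
x_star = jnp.array([0.0, r_min])

# Shifting both points by the same amount leaves the energy unchanged ...
print(energy(x_star), energy(x_star + 5.0))

# ... so the Hessian at the minimum is singular: one eigenvalue is zero.
hess = jax.hessian(energy)(x_star)
eigvals, eigvecs = jnp.linalg.eigh(hess)
print(eigvals)
```

Pinning one point (for example x[0] = 0) and optimizing only the other coordinate removes this flat direction.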

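    You can make some progress here by minimizing the log of the shifted potential. Because the minimum value of F is -1, 2 + F is always at least 1, so the log is well defined, and since log is monotonic the minimizer is unchanged. The transformation smooths the steep gradients and lets the BFGS minimizer find the expected minimum (the two points end up about 1.1225 apart, roughly 2^(1/6)):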

    result = minimize(lambda x: jnp.log(2 + F(x)), x_init, method='BFGS')
    print(result.x)
    # [-0.06123102  1.06123102]
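
The gradient of the transformed objective is F'(r) / (2 + F(r)). Where F is large (the steep wall at small r), the gradient is divided by a large number, while near and to the right of the minimum 2 + F lies between 1 and 2 and the gradient is only mildly rescaled. A minimal sketch comparing the two slopes, assuming epsilon = sigma = 1 and hypothetical helpers `lj` and `log_shifted`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lj(r):
    """Lennard-Jones pair potential with epsilon = sigma = 1 (minimum value -1)."""
    return 4.0 * (r ** -12 - r ** -6)

def log_shifted(r):
    """log(2 + F): monotone in F, so it has the same minimizer r = 2**(1/6)."""
    return jnp.log(2.0 + lj(r))

d_raw = jax.grad(lj)
d_log = jax.grad(log_shifted)

# Compare raw and log-shifted slopes across the wall, the minimum, and the tail.
for r in [0.8, 0.9, 1.0, 2.0 ** (1.0 / 6.0), 2.0, 3.0]:
    print(f"r = {r:.4f}  dF/dr = {float(d_raw(r)):12.4f}  "
          f"d log(2+F)/dr = {float(d_log(r)):10.4f}")
```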
    

    But in general, my suggestion would probably be to opt for a constrained optimization approach instead, perhaps one of the JAXopt constrained optimization methods, where you can rule out problematic regions of the parameter space.
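
As a rough illustration of the constrained idea without depending on JAXopt, here is a hand-rolled projected gradient descent in plain JAX (this is not JAXopt's API). It assumes epsilon = sigma = 1, pins the first point at 0 so only the separation r is optimized, and restricts r to an illustrative box [1.0, 1.5] with a fixed step size of 0.005. The bounds, the step size, the starting point of 2.5, and the names `R_LO`, `R_HI`, `STEP`, and `pgd_step` are all assumptions chosen for this example. Starting on the flat tail, the projection pulls r back into the box, and plain gradient steps then converge to 2^(1/6).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lj(r):
    """Lennard-Jones pair potential with epsilon = sigma = 1."""
    return 4.0 * (r ** -12 - r ** -6)

# Hypothetical box constraint on the separation: keeps r away from the steep
# r -> 0 wall and the flat tail. Bounds and step size are illustrative only.
R_LO, R_HI = 1.0, 1.5
STEP = 0.005

@jax.jit
def pgd_step(r):
    """One projected-gradient step: a gradient step, then clip back into the box."""
    return jnp.clip(r - STEP * jax.grad(lj)(r), R_LO, R_HI)

# First point pinned at 0, so r = x[1] is the only variable (no flat direction).
r = jnp.clip(jnp.asarray(2.5), R_LO, R_HI)  # start out on the flat tail
for _ in range(500):
    r = pgd_step(r)

print(float(r), 2.0 ** (1.0 / 6.0))
```

The box has to contain the minimum, and the step size has to be small relative to the curvature inside the box; a box reaching further into the r -> 0 wall would need a smaller step.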