I am translating some of my R code to Python as a learning exercise, in particular trying JAX for autodiff.
In my functions implementing non-linear least squares, when I set the tolerance to 1e-8, the estimated parameters are nearly identical after several iterations, but the algorithm never appears to converge.
However, the R code converges at the 12th iteration with tol=1e-8 and at the 14th iteration with tol=1e-9, and its estimated parameters are almost the same as those from the Python implementation.
I think this has something to do with floating point, but I am not sure which step I could improve to make it converge as quickly as in R.
Here is my code; most steps are the same as in R:
```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola

def update_parm(X, y, fun, dfun, parm, theta, wt):
    len_y = len(y)
    mean_fun = fun(X, parm)
    if isinstance(wt, bool):
        if wt:
            # variance function var = mean**theta
            var_fun = np.exp(theta * np.log(mean_fun))
            sqrtW = 1 / np.sqrt(var_fun ** 2)
        else:
            sqrtW = np.ones(len_y)  # unweighted; an array so .reshape below works
    else:
        sqrtW = np.asarray(wt)
    gradX = dfun(X, parm)  # use the argument X, not the global x
    weighted_X = sqrtW.reshape(len_y, 1) * gradX
    # working response of the linearised model
    z = gradX @ parm + (y - mean_fun)
    weighted_z = sqrtW * z
    Q, R = ola.qr(weighted_X, mode="economic")
    # R is upper triangular, so a triangular solve is enough
    new_parm = ola.solve_triangular(R, Q.T @ weighted_z)
    return new_parm
```
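For reference, this is how I read the step `update_parm` performs (written out by hand from the code, so treat it as my interpretation rather than a formal derivation), with $\mu_i = f(x_i, \boldsymbol\beta_k)$, $J = J(\boldsymbol\beta_k)$ the Jacobian returned by `dfun`, and `wt=True`:

$$
\mathbf{z} = J\,\boldsymbol\beta_k + \bigl(\mathbf{y} - \boldsymbol\mu\bigr),
\qquad
S = \operatorname{diag}\bigl(1/\mu_i^{\theta}\bigr)
$$

$$
\boldsymbol\beta_{k+1} = \arg\min_{\boldsymbol\beta} \bigl\lVert S(\mathbf{z} - J\boldsymbol\beta)\bigr\rVert_2^2,
\qquad
SJ = QR \;\Rightarrow\; \boldsymbol\beta_{k+1} = R^{-1} Q^\top S\,\mathbf{z}
$$

and `nls_irwls` stops when $\max_j \lvert \beta_{k+1,j} - \beta_{k,j} \rvert / \lvert \beta_{k,j} \rvert < \text{tol}$.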
```python
def nls_irwls(X, y, fun, dfun, init, theta=1, tol=1e-8, maxiter=500):
    old_parm = init
    n_iter = 0
    while n_iter < maxiter:
        new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
        # largest relative change across the parameters
        parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
        print(parm_diff)
        if parm_diff < tol:
            break
        old_parm = new_parm
        n_iter += 1
    print(new_parm)
    if n_iter == maxiter:
        print("The algorithm failed to converge")
    else:
        return {"Estimated coefficient": new_parm}
```
```python
x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])

# bi-exponential model with log-parameterised amplitudes and rates
def model(x, W):
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]) * x)
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]) * x)
    return comp1 * comp2 + comp3 * comp4

init = np.array([0.69, 0.69, -1.6, -1.6])
```
```python
# autodiff: Jacobian with respect to W
model_grad = jax.jit(jax.jacfwd(model, argnums=1))

# manual derivative
def dModel(x, W):
    e1 = np.exp(W[1])
    e2 = np.exp(W[3])
    e5 = np.exp(-(x * e1))
    e6 = np.exp(-(x * e2))
    e7 = np.exp(W[0])
    e8 = np.exp(W[2])
    b1 = e5 * e7
    b2 = -(x * e5 * e7 * e1)
    b3 = e6 * e8
    b4 = -(x * e6 * e8 * e2)
    return np.array([b1, b2, b3, b4]).T

nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
```
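To rule out the derivative itself, a quick sanity check is to compare the manual Jacobian against central finite differences in float64. This is a NumPy-only sketch; `model_np` and `fd_jacobian` are hypothetical helper names, and `dModel` here is a compact copy of the function above.

```python
import numpy as np

# NumPy (float64) copy of the JAX model, used only for checking
def model_np(x, W):
    return np.exp(W[0]) * np.exp(-np.exp(W[1]) * x) + np.exp(W[2]) * np.exp(-np.exp(W[3]) * x)

# compact copy of the manual Jacobian: one column per parameter
def dModel(x, W):
    e1, e2 = np.exp(W[1]), np.exp(W[3])
    e5, e6 = np.exp(-(x * e1)), np.exp(-(x * e2))
    e7, e8 = np.exp(W[0]), np.exp(W[2])
    return np.array([e5 * e7, -(x * e5 * e7 * e1), e6 * e8, -(x * e6 * e8 * e2)]).T

def fd_jacobian(f, x, W, h=1e-6):
    """Central-difference Jacobian of f(x, W) with respect to W (hypothetical helper)."""
    W = np.asarray(W, dtype=np.float64)
    cols = []
    for j in range(W.size):
        step = np.zeros_like(W)
        step[j] = h
        cols.append((f(x, W + step) - f(x, W - step)) / (2 * h))
    return np.column_stack(cols)

x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8], dtype=np.float64)
init = np.array([0.69, 0.69, -1.6, -1.6])

J_manual = dModel(x, init)
J_fd = fd_jacobian(model_np, x, init)
max_err = np.max(np.abs(J_manual - J_fd))
print(J_manual.shape, max_err)
```

A small `max_err` means the analytic derivative matches the model, so any convergence problem lies elsewhere.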
One thing to be aware of is that, by default, JAX performs computations in 32-bit floating point, while tools like R and NumPy use 64-bit. The machine epsilon of float32 is about 1.2e-7, so a relative change of 1e-8 is below what it can resolve. I suspect this is why your program fails to converge.
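A quick NumPy check of that claim, no JAX needed: float32 cannot even register adding 1e-8 to 1.0, while float64 can.

```python
import numpy as np

eps32 = np.finfo(np.float32).eps  # gap between 1.0 and the next float32
eps64 = np.finfo(np.float64).eps  # same for float64
print(eps32, eps64)

one32 = np.float32(1.0)
one64 = np.float64(1.0)
print(one32 + np.float32(1e-8) == one32)  # the 1e-8 is rounded away
print(one64 + 1e-8 == one64)              # float64 keeps it
```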
You can enable 64-bit computation by putting this at the beginning of your script:
```python
from jax import config
config.update('jax_enable_x64', True)
```
After doing this, your program converges as expected. For more information, see JAX Sharp Bits: Double Precision.
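If you want to see the effect without JAX in the loop, here is a NumPy-only sketch that replays the same IRWLS update and stopping rule with every array cast to a chosen dtype. It is my re-implementation for illustration (`irwls` and `jac` are hypothetical names), not your code, and I make no claim that its iteration counts match R exactly. In float32, any nonzero relative change between two values is at least about 6e-8, so the `< 1e-8` test can only pass if an iterate repeats exactly.

```python
import numpy as np

x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])
init = np.array([0.69, 0.69, -1.6, -1.6])

def model(x, W):
    return np.exp(W[0]) * np.exp(-np.exp(W[1]) * x) + np.exp(W[2]) * np.exp(-np.exp(W[3]) * x)

def jac(x, W):
    """Analytic Jacobian of model with respect to W (same as dModel)."""
    e1, e2 = np.exp(W[1]), np.exp(W[3])
    t1 = np.exp(W[0]) * np.exp(-x * e1)
    t2 = np.exp(W[2]) * np.exp(-x * e2)
    return np.column_stack([t1, -x * e1 * t1, t2, -x * e2 * t2])

def irwls(dtype, theta=1.0, tol=1e-8, maxiter=50):
    """Same update and stopping rule as nls_irwls, computed entirely in `dtype`."""
    xd, yd = x.astype(dtype), y.astype(dtype)
    old = init.astype(dtype)
    for it in range(1, maxiter + 1):
        mu = model(xd, old)
        sqrtW = 1 / mu ** theta  # equals 1/sqrt(var**2) with var = mu**theta
        J = jac(xd, old)
        z = J @ old + (yd - mu)
        Q, R = np.linalg.qr(sqrtW[:, None] * J)  # reduced QR
        new = np.linalg.solve(R, Q.T @ (sqrtW * z)).astype(dtype)
        diff = np.max(np.abs(new - old) / np.abs(old))
        if diff < tol:
            return new, it, True
        old = new
    return old, maxiter, False

for dt in (np.float32, np.float64):
    parm, n, converged = irwls(dt)
    print(dt.__name__, converged, n, parm)
```

The only difference between the two runs is the dtype, which isolates precision as the cause.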