I'm working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model containing interpolation. The current optimization using jaxopt.ScipyBoundedMinimize
takes around 30 seconds for 100 iterations, with most of the time seemingly spent during or before the first iteration. You can find the relevant code snippet below, and the necessary data for it at the following link.
necessary data (idc, sg and cpcs)
import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle
with open('idc.pkl', 'rb') as f:
    idc = pickle.load(f)
with open('sg.pkl', 'rb') as f:
    sg = pickle.load(f)
with open('cpcs.pkl', 'rb') as f:
    cpcs = pickle.load(f)
def model(fssc, fssh, time, rv, amp):
fssp = 1.0 - (fssc + fssh)
ivis = cpcs['common'][time]['ivis']
areas = cpcs['common'][time]['areas']
mus = cpcs['common'][time]['mus']
vels = idc['vels'].copy()
ldfs_phot = cpcs['line'][time]['ldfs_phot']
ldfs_cool = cpcs['line'][time]['ldfs_cool']
ldfs_hot = cpcs['line'][time]['ldfs_hot']
lps_phot = cpcs['line'][time]['lps_phot']
lps_cool = cpcs['line'][time]['lps_cool']
lps_hot = cpcs['line'][time]['lps_hot']
lis_phot = cpcs['line'][time]['lis_phot']
lis_cool = cpcs['line'][time]['lis_cool']
lis_hot = cpcs['line'][time]['lis_hot']
coeffs_phot = lis_phot * ldfs_phot * areas * mus
wgt_phot = coeffs_phot * fssp[ivis]
wgtn_phot = jnp.sum(wgt_phot)
coeffs_cool = lis_cool * ldfs_cool * areas * mus
wgt_cool = coeffs_cool * fssc[ivis]
wgtn_cool = jnp.sum(wgt_cool)
coeffs_hot = lis_hot * ldfs_hot * areas * mus
wgt_hot = coeffs_hot * fssh[ivis]
wgtn_hot = jnp.sum(wgt_hot)
prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
prf /= wgtn_phot + wgtn_cool + wgtn_hot
prf = jnp.interp(vels, vels + rv, prf)
prf = prf + amp
avg = jnp.mean(prf)
prf = prf / avg
return prf
def loss(x0s, lmbd):
noes = sg['noes']
noo = len(idc['times'])
fssc = x0s[:noes]
fssh = x0s[noes: 2 * noes]
fssp = 1.0 - (fssc + fssh)
rv = x0s[2 * noes: 2 * noes + noo]
amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]
chisq = 0
for i, itime in enumerate(idc['times']):
oprf = idc['data'][itime]['prf']
oprf_errs = idc['data'][itime]['errs']
nop = len(oprf)
sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])
chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)
wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])
mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
(1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']
ftot = chisq + lmbd * mem
return ftot
if __name__ == '__main__':
# idc: a dictionary containing observational data (150 x 32)
# sg and cpcs: dictionaries with related coefficients
noes = sg['noes']
lmbd = 1.0
maxiter = 1000
tol = 1e-5
fss = jnp.ones(2 * noes) * 1e-5
x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))
minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2
bounds = (minx0s, maxx0s)
start = ela_time.time()
optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
options={'disp': True})
x0s, info = optimizer.run(x0s, bounds, lmbd)
# optimizer = optax.adam(learning_rate=0.1)
# optimizer_state = optimizer.init(x0s)
#
# for i in range(1, maxiter + 1):
#
# print('ITERATION -->', i)
#
# gradients = jax.grad(loss)(x0s, lmbd)
# updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
# x0s = optax.apply_updates(x0s, updates)
# x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
# print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))
end = ela_time.time()
print(end - start) # total elapsed time: ~30 seconds
Here's a breakdown of the relevant aspects:

- Number of free parameters (x0s): 5263
- Data: the idc dictionary (4800 data points)
- The model function also utilizes interpolation
- Optimizer: jaxopt.ScipyBoundedMinimize with the L-BFGS-B method (slow, ~30 seconds, with most of the time spent during or just before the first iteration)
- I also tried optax.adam, yet due to the inherent nature of the modeling I couldn't succeed, as the x0s couldn't be divided (assuming I understood parallelization correctly)

Questions:

- How can I speed up ScipyBoundedMinimize?
- Are there alternative optimization methods in jax that might be faster for my scenario (a large number of free parameters and data points, and a complex model with interpolation)?
- What about optax.adam? Are there any strategies for potential parallelization in this case?

Additional information:
I'd appreciate any insights or suggestions to improve the optimization performance.
JAX code is just-in-time (JIT) compiled, meaning that the long duration of the first step is likely related to compilation costs: the longer the traced program is, the longer it takes to compile.
One common issue leading to long compile times is the use of Python control flow such as for loops. JAX's tracing machinery essentially flattens out these loops (see JAX Sharp Bits: Control Flow). In your case, you loop over every entry of idc['times'] in Python, so each iteration is traced separately and you end up with a very long and inefficient program.
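To see the effect, here is a minimal sketch (not your model, just an illustration) comparing the size of the traced program for a Python-level loop versus a vectorized formulation:
import jax
import jax.numpy as jnp

def loss_loop(x):
    # Each iteration of the Python loop is traced separately, so the traced
    # program (and its compile time) grows linearly with the loop length.
    total = 0.0
    for i in range(150):
        total = total + jnp.sum((x + i) ** 2)
    return total

def loss_vectorized(x):
    # The same computation expressed with vmap traces to a short program.
    offsets = jnp.arange(150)
    return jnp.sum(jax.vmap(lambda i: jnp.sum((x + i) ** 2))(offsets))

print(len(str(jax.make_jaxpr(loss_loop)(jnp.ones(32)))))        # long jaxpr
print(len(str(jax.make_jaxpr(loss_vectorized)(jnp.ones(32)))))  # short jaxpr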
The typical solution in a case like this is to rewrite your program using jax.vmap
. Like most JAX constructs, this works best with a struct-of-arrays pattern rather than the array-of-structs pattern used in your data. So the first step to using vmap
is to restructure your data in a way that JAX can use; it might look something like this:
itimes = jnp.arange(len(idc['times']))
oprf = jnp.array([idc['data'][t]['prf'] for t in idc['times']])
oprf_errs = jnp.array([idc['data'][t]['errs'] for t in idc['times']])
sprf = jax.vmap(model, in_axes=[None, None, 0, 0, 0])(fssc, fssh, itimes, rv, amp)
chi2 = jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (len(itimes) * sprf.shape[1])
This will not work directly: you'll also have to restructure the data used by your model
function into the struct-of-arrays style, but hopefully this gives you the general idea.
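Concretely, the struct-of-arrays restructuring of the model inputs might look like the following rough sketch (field names are taken from the dictionaries in your code; it assumes every per-time entry in cpcs has the same shape):
times = list(idc['times'])
# Stack each per-time field of cpcs into a single array with a leading time axis.
common = {key: jnp.array([cpcs['common'][t][key] for t in times])
          for key in ('ivis', 'areas', 'mus')}
line = {key: jnp.array([cpcs['line'][t][key] for t in times])
        for key in ('ldfs_phot', 'ldfs_cool', 'ldfs_hot',
                    'lps_phot', 'lps_cool', 'lps_hot',
                    'lis_phot', 'lis_cool', 'lis_hot')}
# model would then take the per-time slices as arguments instead of a time key,
# and jax.vmap(model, in_axes=(None, None, 0, 0, 0, 0)) would map over the
# leading axis of common, line, rv, and amp.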
Note also that this assumes that every entry of idc['data'][i]['prf']
and idc['data'][i]['errs']
has the same shape. If that's not the case, then I'm afraid your problem is not particularly well-suited to JAX's SPMD programming model, and there's not an easy way to work around the need for long compilations.
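A quick way to check that assumption on your data (a small sketch using the dictionaries loaded above) is to collect the shapes and confirm there is only one of each:
prf_shapes = {jnp.shape(idc['data'][t]['prf']) for t in idc['times']}
err_shapes = {jnp.shape(idc['data'][t]['errs']) for t in idc['times']}
print(prf_shapes, err_shapes)  # vmap needs a single common shape in each set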