Tags: python, performance, optimization, jax

Slow JAX Optimization with ScipyBoundedMinimize and Optax - Seeking Speedup Strategies


I'm working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model containing interpolation. The current optimization using jaxopt.ScipyBoundedMinimize takes around 30 seconds for 100 iterations, with most of that time apparently spent during or before the first iteration. The relevant code snippet is below, and the data it needs is available at the following link.

necessary data (idc, sg and cpcs)

import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle


with open('idc.pkl', 'rb') as f:
    idc = pickle.load(f)

with open('sg.pkl', 'rb') as f:
    sg = pickle.load(f)

with open('cpcs.pkl', 'rb') as f:
    cpcs = pickle.load(f)


def model(fssc, fssh, time, rv, amp):

    fssp = 1.0 - (fssc + fssh)

    # Per-time geometric quantities for this phase.
    ivis = cpcs['common'][time]['ivis']
    areas = cpcs['common'][time]['areas']
    mus = cpcs['common'][time]['mus']

    vels = idc['vels'].copy()

    ldfs_phot = cpcs['line'][time]['ldfs_phot']
    ldfs_cool = cpcs['line'][time]['ldfs_cool']
    ldfs_hot = cpcs['line'][time]['ldfs_hot']

    lps_phot = cpcs['line'][time]['lps_phot']
    lps_cool = cpcs['line'][time]['lps_cool']
    lps_hot = cpcs['line'][time]['lps_hot']

    lis_phot = cpcs['line'][time]['lis_phot']
    lis_cool = cpcs['line'][time]['lis_cool']
    lis_hot = cpcs['line'][time]['lis_hot']

    # Weights for each component: per-element coefficients scaled by the
    # corresponding filling factors over the visible elements.
    coeffs_phot = lis_phot * ldfs_phot * areas * mus
    wgt_phot = coeffs_phot * fssp[ivis]
    wgtn_phot = jnp.sum(wgt_phot)

    coeffs_cool = lis_cool * ldfs_cool * areas * mus
    wgt_cool = coeffs_cool * fssc[ivis]
    wgtn_cool = jnp.sum(wgt_cool)

    coeffs_hot = lis_hot * ldfs_hot * areas * mus
    wgt_hot = coeffs_hot * fssh[ivis]
    wgtn_hot = jnp.sum(wgt_hot)

    # Weighted sum of the local line profiles, normalized by the total weight.
    prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool +
                  wgt_hot[:, None] * lps_hot, axis=0)
    prf /= wgtn_phot + wgtn_cool + wgtn_hot

    # Shift the profile in velocity by rv via linear interpolation.
    prf = jnp.interp(vels, vels + rv, prf)

    # Apply the additive offset and normalize to the mean level.
    prf = prf + amp
    prf = prf / jnp.mean(prf)

    return prf


def loss(x0s, lmbd):

    noes = sg['noes']

    noo = len(idc['times'])

    fssc = x0s[:noes]
    fssh = x0s[noes: 2 * noes]
    fssp = 1.0 - (fssc + fssh)
    rv = x0s[2 * noes: 2 * noes + noo]
    amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]

    # Accumulate the reduced chi-square over all observation times.
    chisq = 0.0
    for i, itime in enumerate(idc['times']):
        oprf = idc['data'][itime]['prf']
        oprf_errs = idc['data'][itime]['errs']

        nop = len(oprf)

        sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])

        chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)

    wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])

    # Maximum-entropy regularization term, weighted by relative grid areas.
    mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
                        (1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']

    ftot = chisq + lmbd * mem

    return ftot


if __name__ == '__main__':

    # idc: a dictionary containing observational data (150 x 32)
    # sg and cpcs: dictionaries with related coefficients

    noes = sg['noes']
    lmbd = 1.0
    maxiter = 1000
    tol = 1e-5

    fss = jnp.ones(2 * noes) * 1e-5
    x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))

    minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
    maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2

    bounds = (minx0s, maxx0s)

    start = ela_time.time()

    optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
                                     options={'disp': True})
    x0s, info = optimizer.run(x0s, bounds, lmbd)

    # optimizer = optax.adam(learning_rate=0.1)
    # optimizer_state = optimizer.init(x0s)
    #
    # for i in range(1, maxiter + 1):
    #
    #     print('ITERATION -->', i)
    #
    #     gradients = jax.grad(loss)(x0s, lmbd)
    #     updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
    #     x0s = optax.apply_updates(x0s, updates)
    #     x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
    #     print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))

    end = ela_time.time()

    print(end - start)   # total elapsed time: ~30 seconds

Here's a breakdown of the relevant aspects:

  • Number of free parameters (x0s): 5263
  • Data: Observational data stored in idc dictionary (4800 data points)
  • Model: Defined in model function, also utilizes interpolation
  • Optimization methods tried:
    • jaxopt.ScipyBoundedMinimize with L-BFGS-B method (slow ~30 seconds, with most of the time spent during or just before the first iteration)
    • optax.adam (too slow ~200 seconds)
  • Attempted parallelization: I tried to parallelize the optax.adam loop, but because of the structure of the model I couldn't split x0s into independent chunks (assuming I've understood parallelization correctly).

Questions:

  1. What are potential reasons for the slowness before or during the first iteration of ScipyBoundedMinimize?
  2. Are there alternative optimization algorithms in JAX that might be faster for my scenario (a large number of free parameters and data points, and a complex model with interpolation)?
  3. Did I misunderstand parallelization with optax.adam? Are there any strategies for potential parallelization in this case?
  4. Are there any code optimizations within the provided snippet that could improve performance (e.g., vectorization)?

Additional Information:

  • Hardware: Intel® Core™ i7-9750H CPU @ 2.60GHz × 12, 16 GiB RAM (laptop)
  • Software: OS Ubuntu 22.04, Python 3.10.12, JAX 0.4.25, optax 0.2.1

I'd appreciate any insights or suggestions to improve the optimization performance.


Solution

  • JAX code is just-in-time (JIT) compiled, meaning that the long duration of the first step is most likely compilation cost. The longer the traced program, the more time it takes to compile.
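
    One way to check this is to time compilation and execution separately. The sketch below uses jax.jit's ahead-of-time lowering API (available in JAX 0.4.x); loss, x0s, and lmbd are the objects from your snippet:

    import time
    import jax

    loss_jit = jax.jit(loss)

    # Trace and compile the loss without running it; with a long
    # unrolled program, this is likely where most of the time goes.
    start = time.time()
    compiled = loss_jit.lower(x0s, lmbd).compile()
    print('compile time:', time.time() - start)

    # Pure execution time; block_until_ready() accounts for JAX's
    # asynchronous dispatch.
    start = time.time()
    compiled(x0s, lmbd).block_until_ready()
    print('run time:', time.time() - start)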

    One common issue leading to long compile times is the use of Python control flow such as for loops: JAX's tracing machinery flattens these loops out, unrolling one copy of the body per iteration (see JAX Sharp Bits: Control Flow). In your case, you loop in Python over each of the ~150 observation times in your dataset, and thus are creating a very long and inefficient program.
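
    You can see this directly with jax.make_jaxpr: for a toy function (not your model), the traced program grows in proportion to the Python loop length:

    import jax
    import jax.numpy as jnp

    def summed(x, n):
        total = 0.0
        for i in range(n):  # Python loop: unrolled at trace time
            total = total + jnp.sin(x + i)
        return total

    # The jaxpr for n=100 contains ~100 copies of the loop body, so its
    # text is roughly 10x longer than for n=10, and compile time scales
    # accordingly.
    print(len(str(jax.make_jaxpr(lambda x: summed(x, 10))(1.0))))
    print(len(str(jax.make_jaxpr(lambda x: summed(x, 100))(1.0))))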

    The typical solution in a case like this is to rewrite your program using jax.vmap. Like most JAX constructs, this works best with a struct-of-arrays pattern rather than the array-of-structs pattern used in your data. So the first step to using vmap is to restructure your data in a way that JAX can use; it might look something like this:

    itimes = jnp.arange(len(idc['times']))
    oprf = jnp.array([idc['data'][t]['prf'] for t in idc['times']])
    oprf_errs = jnp.array([idc['data'][t]['errs'] for t in idc['times']])

    sprf = jax.vmap(model, in_axes=(None, None, 0, 0, 0))(fssc, fssh, itimes, rv, amp)
    chi2 = jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / len(itimes) / sprf.shape[1]
    

    This will not work directly: you'll also have to restructure the data used by your model function into the struct-of-arrays style, but hopefully this gives you the general idea.
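
    For concreteness, the restructuring of cpcs might look like the sketch below. This is only a sketch: it assumes every per-time entry has the same shape (in particular ivis, which it may not), and model_sa and stacked are hypothetical names whose body simply mirrors your model():

    # Stack each per-time entry of cpcs into a dict of arrays with a
    # leading time axis; vmap can map over such a pytree directly.
    line_keys = ['ldfs_phot', 'ldfs_cool', 'ldfs_hot',
                 'lps_phot', 'lps_cool', 'lps_hot',
                 'lis_phot', 'lis_cool', 'lis_hot']
    stacked = {k: jnp.stack([cpcs['common'][t][k] for t in idc['times']])
               for k in ['ivis', 'areas', 'mus']}
    stacked |= {k: jnp.stack([cpcs['line'][t][k] for t in idc['times']])
                for k in line_keys}

    def model_sa(fssc, fssh, c, rv, amp):
        # Same computation as model(), but the per-time quantities come
        # from the stacked dict c instead of indexing cpcs by time.
        fssp = 1.0 - (fssc + fssh)
        wgt_phot = c['lis_phot'] * c['ldfs_phot'] * c['areas'] * c['mus'] * fssp[c['ivis']]
        wgt_cool = c['lis_cool'] * c['ldfs_cool'] * c['areas'] * c['mus'] * fssc[c['ivis']]
        wgt_hot = c['lis_hot'] * c['ldfs_hot'] * c['areas'] * c['mus'] * fssh[c['ivis']]
        prf = jnp.sum(wgt_phot[:, None] * c['lps_phot']
                      + wgt_cool[:, None] * c['lps_cool']
                      + wgt_hot[:, None] * c['lps_hot'], axis=0)
        prf /= jnp.sum(wgt_phot) + jnp.sum(wgt_cool) + jnp.sum(wgt_hot)
        vels = idc['vels']
        prf = jnp.interp(vels, vels + rv, prf) + amp
        return prf / jnp.mean(prf)

    # in_axes=0 for the dict argument maps every array in it over its
    # leading (time) axis; fssc and fssh are shared across times.
    sprf = jax.vmap(model_sa, in_axes=(None, None, 0, 0, 0))(fssc, fssh, stacked, rv, amp)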

    Note also that this assumes that every entry of idc['data'][i]['prf'] and idc['data'][i]['errs'] has the same shape. If that's not the case, then I'm afraid your problem is not particularly well-suited to JAX's SPMD programming model, and there's not an easy way to work around the need for long compilations.
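
    That said, if the shapes differ only slightly, one common mitigation is to pad each profile to a common length and mask the padded points out of the chi-square. A minimal sketch, where pad_to and the fill choices are illustrative (sprf would need to be padded to the same length as well):

    lengths = [len(idc['data'][t]['prf']) for t in idc['times']]
    nmax = max(lengths)

    def pad_to(a, n, fill=0.0):
        # Right-pad a 1-D array to length n with a constant fill value.
        a = jnp.asarray(a)
        return jnp.pad(a, (0, n - len(a)), constant_values=fill)

    oprf = jnp.stack([pad_to(idc['data'][t]['prf'], nmax) for t in idc['times']])
    # Fill padded error entries with 1.0 so the division stays finite.
    oprf_errs = jnp.stack([pad_to(idc['data'][t]['errs'], nmax, fill=1.0)
                           for t in idc['times']])
    mask = jnp.stack([pad_to(jnp.ones(n), nmax) for n in lengths])

    # Masked chi-square: padded points contribute zero to the sum.
    chi2 = jnp.sum(mask * ((oprf - sprf) / oprf_errs) ** 2) / jnp.sum(mask)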