
Vectorized minimization and root finding in JAX


I have a family of functions parameterized by args

f(x, args)

and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My first attempt was to loop over the different values of args and call a scipy.optimize routine at each iteration, but that takes too long. I believe the computation can be sped up with vectorization. My next attempt was to use jax.vmap inside jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value for args.

Alternatively, I can code my own vectorized optimization method, e.g. bisection. By vectorized I mean operating on arrays for a fixed number of iterations, without stopping early when one of the optimization problems reaches its error tolerance before the others.

I was hoping to use some already optimized, off-the-shelf algorithm if an implementation is available in JAX. This thread is related, but there the args are not changing.


Solution

  • You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:

    import jax
    import jax.numpy as jnp
    from jax.scipy import optimize
    
    def f(x, args):
      a, b = args
      return jnp.sum(a + (x - b) ** 2)
    
    def find_min(a, b):
      x0 = jnp.array([1.0])
      args = (a, b)
      return optimize.minimize(f, x0, (args,), method="BFGS")
    
    a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))
    
    results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())
    
    print(results.success)
    # [ True  True  True  True  True  True  True  True  True  True  True  True
    #   True  True  True  True  True  True  True  True  True  True  True  True
    #   True]
    
    print(results.x.T)
    # [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
    #   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]
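
The question also mentions a hand-written bisection that runs for a fixed number of iterations. The following is a minimal sketch of that idea, not a tuned implementation. It assumes x is a scalar and that the derivative of f changes sign exactly once inside the bracket [lo, hi]. The names bisect_min, lo, hi and num_iters are hypothetical, and the objective is the scalar version of the toy function above. Bisection is applied to the derivative (obtained here with jax.grad), and jax.vmap batches it over args:

```python
import jax
import jax.numpy as jnp

def f(x, args):
    # Scalar version of the toy objective above; its minimum is at x = b.
    a, b = args
    return a + (x - b) ** 2

# Derivative of f with respect to x (the first argument).
df = jax.grad(f)

def bisect_min(a, b, lo=-10.0, hi=10.0, num_iters=30):
    # Assumption: f'(lo) < 0 < f'(hi), so the minimizer lies inside [lo, hi].
    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        # A negative slope at mid means the minimizer is to the right of mid.
        go_right = df(mid, (a, b)) < 0
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
        return lo, hi

    # float32 is chosen explicitly so the loop carry keeps a fixed dtype.
    init = (jnp.asarray(lo, dtype=jnp.float32),
            jnp.asarray(hi, dtype=jnp.float32))
    # Fixed iteration count: no early stopping, so every problem does the same work.
    lo, hi = jax.lax.fori_loop(0, num_iters, body, init)
    return 0.5 * (lo + hi)

a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

# One bisection per (a, b) pair, batched by vmap and compiled by jit.
x_min = jax.jit(jax.vmap(bisect_min))(a_grid.ravel(), b_grid.ravel())
```

For this toy objective, x_min should agree with b_grid.ravel() up to floating-point tolerance. Because the iteration count is fixed, the loop body is straight-line array code, which is the form the question describes. The trade-off is that the bracket and iteration count have to be chosen by hand for the problem.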