I have a family of functions parameterized by args, f(x, args), and I want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My first attempt was to loop over the different values of args and call a scipy optimizer at each iteration, but it takes too long. I believe the operations can be sped up with vectorization. My next attempt was to combine jax.vmap with jax.scipy.optimize.minimize or jaxopt.ScipyMinimize, but I can't seem to pass more than one value of args.
Alternatively, I could code my own vectorized optimization method, e.g. bisection, where by vectorized I mean operating on arrays for a fixed number of iterations rather than stopping early once an individual problem reaches a given error tolerance (see the sketch below). However, I was hoping to use an already optimized, off-the-shelf algorithm if an implementation is available in JAX. This thread is related, but there the args are not changing.
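For reference, here is a minimal sketch of what I mean by a fixed-iteration, vectorized bisection; the quadratic objective, the bracket [-10, 10], and the iteration count are just placeholders:

import jax
import jax.numpy as jnp

# Toy example: f(x, args) = a + (x - b)**2, so df/dx = 2 * (x - b).
def dfdx(x, a, b):
    return 2.0 * (x - b)

def bisect_minimum(a, b, n_iters=50):
    # Fixed bracket, assumed to contain the minimizer.
    lo = jnp.full_like(b, -10.0)
    hi = jnp.full_like(b, 10.0)

    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        go_right = dfdx(mid, a, b) < 0.0  # minimizer lies to the right of mid
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
        return lo, hi

    # Same fixed number of iterations for every problem, so it stays vmappable.
    lo, hi = jax.lax.fori_loop(0, n_iters, body, (lo, hi))
    return 0.5 * (lo + hi)

a_vals = jnp.arange(5.0)
b_vals = jnp.arange(5.0)
print(jax.vmap(bisect_minimum)(a_vals, b_vals))  # approximately b_vals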
You can define a function to find the minimum given particular args, and then wrap it in jax.vmap to automatically vectorize it. For example:
import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
    # Objective for a single parameter set; the minimum over x is at x = b.
    a, b = args
    return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
    # Solve one minimization problem for a single (a, b) pair.
    x0 = jnp.array([1.0])
    args = (a, b)
    return optimize.minimize(f, x0, (args,), method="BFGS")

# Build a grid of parameter values and solve all problems at once with vmap.
a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))
results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]
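If the vmapped solve is still slow for N = 1000 parameter sets, one further (standard JAX) step you could try is wrapping the vmapped function in jax.jit, so the whole batched solve is compiled once and reused for later calls with the same shapes:

# Compile the batched solver; later calls with arrays of the same shape
# and dtype reuse the compiled version.
batched_find_min = jax.jit(jax.vmap(find_min))
results = batched_find_min(a_grid.ravel(), b_grid.ravel())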