I need to minimize a simple function that divides two values. The optimization parameter `x` is an `(n, m)` numpy array from which I calculate a float.
```python
# An initial value
normX0 = calculate_normX(x_start)

def objective(x) -> float:
    """Objective function."""
    x = x.reshape((n, m))
    normX = calculate_normX(x)
    return -(float(normX) / float(normX0))
```
`calculate_normX()` is a wrapper function around an external (Java) API that takes the ndarray as input and outputs a float, in this case the norm of a vector. For the optimization I was using `jax` and `jaxopt`, since jaxopt supports automatic differentiation of `objective`:
```python
solver = NonlinearCG(fun=objective, maxiter=5, verbose=True)
res = solver.run(x.flatten())
```
or the regular SciPy `minimize`:
```python
objective_jac = jax.jacrev(objective)
minimize(objective, jac=objective_jac, x0=x.flatten(), method='L-BFGS-B', options={'maxiter': 2})
```
In both cases, however, `x` is not changed during the optimization step. Even when initializing `x` with random values, the optimizer does not seem to work. I also tried other solvers, like jaxopt's `NonlinearCG`. What am I doing wrong?
The external, non-JAX function call is almost certainly the source of the problem. Non-JAX function calls within JAX transforms like `jacrev` effectively get replaced with a trace-time constant (and in most cases will raise an error), so it makes sense that the optimization never changes `x`: as far as autodiff is concerned, the objective is constant and its gradient is zero.
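To see this concretely, here is a minimal sketch in which a hypothetical `external_norm` stands in for the Java call. Under `jacrev` the input is an abstract tracer, and plain NumPy cannot materialize it:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the external call: plain NumPy, invisible to JAX.
def external_norm(x):
    return np.linalg.norm(np.asarray(x))

# Under jacrev, x is an abstract tracer; NumPy cannot convert it to a
# concrete array, so the call fails rather than being differentiated.
try:
    jax.jacrev(external_norm)(jnp.ones(4))
except jax.errors.TracerArrayConversionError as err:
    print(err)
```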
The best approach would be to define your `calculate_normX` function using JAX rather than calling out to an external API, and then everything should work automatically.
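For example, if the external call is just computing a vector/matrix norm, a pure-JAX version might look like the sketch below. The shape, initial value, and `jnp.linalg.norm` are assumptions standing in for your actual setup. Note that the `float(...)` casts from the question must also go: forcing a traced value to a Python float breaks differentiation in the same way the external call does.

```python
import jax.numpy as jnp
from jaxopt import NonlinearCG

n, m = 3, 4                   # example shape
x_start = jnp.ones((n, m))    # example initial value

def calculate_normX(x):
    # Hypothetical JAX re-implementation; it must compute the same
    # quantity as the Java API (here assumed to be a Frobenius norm).
    return jnp.linalg.norm(x)

normX0 = calculate_normX(x_start)

def objective(x):
    x = x.reshape((n, m))
    return -(calculate_normX(x) / normX0)  # keep everything as JAX values

solver = NonlinearCG(fun=objective, maxiter=5)
res = solver.run(x_start.flatten())
```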
If you must call an external API, one way to do this in JAX is to use `pure_callback` along with `custom_jvp` to define the autodiff rule for your external callback. There is an example of this in the JAX docs: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp