python optimization scipy-optimize-minimize jax

Python function minimization not changing optimization variables


I need to minimize a simple function that divides two values. The optimization parameter x is an (n, m) NumPy array from which I calculate a float.

# An initial value
normX0 = calculate_normX(x_start)

def objective(x) -> float:
    """Objective function """
    x = x.reshape((n,m))
    normX = calculate_normX(x)
    return -(float(normX) / float(normX0))

calculate_normX() is a wrapper around an external (Java) API that takes the ndarray as input and returns a float, in this case the norm of a vector. For the optimization I was using jax and jaxopt, since they support automatic differentiation of the objective.

solver = NonlinearCG(fun=objective, maxiter=5, verbose=True)
res = solver.run(x.flatten())

or the regular scipy minimize

objective_jac = jax.jacrev(objective)
minimize(objective, jac=objective_jac, x0=x.flatten(), method='L-BFGS-B', options={'maxiter': 2})

In both cases, however, x is not changed during the optimization step. Even when I initialize x with random values, the optimizer does not seem to work. I also tried other solvers, such as jaxopt's NonlinearCG. What am I doing wrong?


Solution

  • The external, non-JAX function call is almost certainly the source of the problem. Non-JAX function calls within JAX transforms like jacrev effectively get replaced with a trace-time constant (and in most cases will raise an error). The derivative of a constant is zero, so the optimizer sees no direction in which to move x, and it makes sense that your optimization does not change its value.

    The best approach would be to define your calculate_normX function using JAX rather than calling out to an external API, and then everything should work automatically.
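As a minimal sketch of the JAX-native route, assuming the external API computes a plain Euclidean norm (the question does not say exactly what it computes), the objective could look like the following. The shape `n, m` and the values in `x_start` are made up for illustration. Note that the `float()` casts from the question are dropped, since they cannot be applied to traced values.

```python
import jax
import jax.numpy as jnp

n, m = 3, 2  # hypothetical shape, for illustration only


def calculate_normX(x):
    # Assumption: a Euclidean norm stands in for the external (Java) API.
    return jnp.linalg.norm(x.ravel())


# Hypothetical starting point; any nonzero (n, m) array would do.
x_start = jnp.arange(1.0, n * m + 1.0).reshape((n, m))
normX0 = calculate_normX(x_start)


def objective(x):
    """Negative norm of x, relative to the norm of the starting point."""
    x = x.reshape((n, m))
    # No float() casts here: they would fail on JAX tracers.
    return -(calculate_normX(x) / normX0)


# Because every operation is a JAX operation, the gradient is nonzero.
grad = jax.grad(objective)(x_start.ravel())
```

With the whole computation expressed in jax.numpy, the gradient depends on x, so a gradient-based solver has something to work with.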

    If you must call out to an external API, one way to do this in JAX is to use pure_callback along with custom_jvp to define the autodiff rule for your external callback. There is an example of this in the JAX docs: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp
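A rough sketch of that pattern, under some loud assumptions: `external_norm` and `external_norm_grad` are hypothetical NumPy stand-ins for the external API, and the derivative of the external function is assumed to be known (or obtainable some other way, such as finite differences), because JAX cannot differentiate through the callback itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def external_norm(x):
    # Hypothetical stand-in for the external API call; runs on the host.
    x = np.asarray(x)
    return np.asarray(np.linalg.norm(x), dtype=x.dtype)


def external_norm_grad(x):
    # Assumption: the derivative is supplied by hand. For a Euclidean
    # norm it is x / ||x||.
    x = np.asarray(x)
    return (x / np.linalg.norm(x)).astype(x.dtype)


@jax.custom_jvp
def calculate_normX(x):
    # Tell JAX the shape and dtype of the callback's result up front.
    out_type = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(external_norm, out_type, x)


@calculate_normX.defjvp
def calculate_normX_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    grad_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    g = jax.pure_callback(external_norm_grad, grad_type, x)
    # The tangent output is linear in dx, which lets jax.grad transpose it.
    return calculate_normX(x), jnp.sum(g * dx)


x = jnp.array([3.0, 4.0])
value = calculate_normX(x)          # norm of (3, 4) is 5
grad = jax.grad(calculate_normX)(x)  # x / ||x|| = (0.6, 0.8)
```

The custom rule is what makes the gradient nonzero here. The trade-off is that its correctness is entirely your responsibility, so it is worth checking it against finite differences for your real external function.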