I wish to obtain the total number of gradient evaluations during optimization using jax.scipy.optimize.minimize. How can I do so?
You can find this in the `njev` attribute (number of Jacobian, i.e. gradient, evaluations) of the `OptimizeResults` object returned by `minimize`. For example:
```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def f(x):
    return jnp.sum(x ** 2)

out = minimize(f, jnp.array([1., 2.]), method="BFGS")
print(out.njev)
# 3
```
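The same result object also carries related counters: `nfev` (number of function evaluations) and `nit` (number of iterations). A minimal sketch, reusing the toy objective above, that reads all three and converts them from JAX scalar arrays to plain Python ints. The exact counts depend on the solver and starting point, so no specific values are assumed here:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def f(x):
    # Toy quadratic objective with its minimum at the origin
    return jnp.sum(x ** 2)

out = minimize(f, jnp.array([1., 2.]), method="BFGS")

# The counters are stored as JAX scalar arrays; int() converts them to Python ints
n_grad = int(out.njev)  # gradient (Jacobian) evaluations
n_fun = int(out.nfev)   # objective function evaluations
n_iter = int(out.nit)   # optimizer iterations

print("gradient evaluations:", n_grad)
print("function evaluations:", n_fun)
print("iterations:", n_iter)
```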
You can find a list of the information available in the optimization output in the documentation for `OptimizeResults`.
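Since `OptimizeResults` is defined as a `NamedTuple`, you can also inspect its fields directly from Python instead of looking them up in the docs. A small sketch, assuming the same toy objective as above:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def f(x):
    return jnp.sum(x ** 2)

out = minimize(f, jnp.array([1., 2.]), method="BFGS")

# NamedTuple exposes its field names via _fields
field_names = list(out._fields)
for name in field_names:
    print(name, getattr(out, name))
```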