jax

How to compute the number of gradient evaluations in jax.scipy.optimize.minimize?


I wish to obtain the total number of gradient evaluations during optimization using jax.scipy.optimize.minimize. How can I do so?


Solution

  • You can find this in the njev (number of Jacobian evaluations) attribute of the OptimizeResults object returned by minimize. For a scalar objective, the Jacobian is the gradient. For example:

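    import jax.numpy as jnp
    from jax.scipy.optimize import minimize
    
    # Simple quadratic objective; its minimum is at x = 0
    def f(x):
      return jnp.sum(x ** 2)
    
    out = minimize(f, jnp.array([1., 2.]), method="BFGS")
    # njev: number of Jacobian (i.e. gradient) evaluations
    print(out.njev)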
    
    3
    

    The documentation for OptimizeResults lists all the information available in the optimization output, including nfev (number of function evaluations) and nit (number of iterations).
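A minimal sketch building on the example above: it reads njev together with a few of the other OptimizeResults fields (success, nit, nfev, x) for the same quadratic objective. The variable names n_grad, n_fun and n_iter are just for illustration. The counters come back as 0-d JAX arrays, so the sketch converts them with int() before using them as plain Python numbers.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def f(x):
    # Simple quadratic objective with its minimum at x = 0
    return jnp.sum(x ** 2)


# method is a required keyword argument in jax.scipy.optimize.minimize
out = minimize(f, jnp.array([1., 2.]), method="BFGS")

# The counters are 0-d JAX arrays; int() turns them into Python ints
n_grad = int(out.njev)  # gradient (Jacobian) evaluations
n_fun = int(out.nfev)   # objective function evaluations
n_iter = int(out.nit)   # BFGS iterations

print("converged:", bool(out.success))
print("iterations:", n_iter)
print("function evaluations:", n_fun)
print("gradient evaluations:", n_grad)
print("solution:", out.x)
```

Converting with int() is mainly a convenience for logging or comparisons in ordinary Python code; inside a jitted function the fields would stay as traced arrays.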