I'm trying to use the BFGS optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f, which is estimated from pseudo-random samples and takes a jax.random.PRNGKey as an argument. To use this function with the JAX/TFP BFGS minimizers, I wrap it in a lambda:
seed = 100
key = jax.random.PRNGKey(seed)
fun = lambda x: f(x, key)  # a lambda body is an expression; `return` is a syntax error here
result = jax.scipy.optimize.minimize(fun=fun, ...)
What is the best way to update the key each time the minimization routine calls the objective, so that I use different pseudo-random numbers in a reproducible way? Maybe a global key variable? If so, is there an example I could follow?
Secondly, is there a way to make the optimization stop after a certain amount of time, as one can do with a callback in scipy? I could use the scipy implementation of BFGS/L-BFGS-B/etc. directly and use JAX only to estimate the function and its gradients, which seems to work; a sketch of what I mean follows. Is there a difference between the scipy, jax.scipy, and tfp.substrates.jax BFGS implementations?
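For reference, a minimal sketch of that SciPy-plus-JAX approach (the objective, dimensions, and time budget are illustrative; stopping by raising StopIteration from the callback is supported in recent SciPy versions, while older ones only support early stopping for method='trust-constr'):

    import time
    import jax
    import jax.numpy as jnp
    import numpy as np
    from scipy.optimize import minimize

    key = jax.random.PRNGKey(100)

    @jax.jit
    def f(x):
        # Illustrative noisy objective with a fixed key.
        noise = jax.random.normal(key, x.shape)
        return jnp.sum((x - 1.0) ** 2) + 0.01 * jnp.dot(noise, x)

    grad_f = jax.jit(jax.grad(f))

    t0, budget = time.monotonic(), 10.0  # wall-clock budget in seconds

    def callback(xk):
        # Recent SciPy terminates the solver when the callback raises StopIteration.
        if time.monotonic() - t0 > budget:
            raise StopIteration

    result = minimize(
        lambda x: float(f(jnp.asarray(x))),  # SciPy passes and expects NumPy types
        x0=np.zeros(3),
        jac=lambda x: np.asarray(grad_f(jnp.asarray(x))),
        method="L-BFGS-B",
        callback=callback,
    )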
Finally, is there a way to print the values of the arguments of fun during the BFGS optimization in jax.scipy or tfp, given that f is jitted?
Thank you!
There is no way to do what you're asking with jax.scipy.optimize.minimize
, because the minimizer offers no way to carry changing state (such as a PRNG key) between function calls, and provides no built-in stochasticity.
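The closest you can get is to close over a single fixed key, so every evaluation within one run sees the same pseudo-random draws, and to vary the seed across runs for reproducible variation. A minimal sketch (the noisy objective f is illustrative):

    import jax
    import jax.numpy as jnp

    def f(x, key):
        # Illustrative objective estimated from pseudo-random samples.
        noise = jax.random.normal(key, x.shape)
        return jnp.sum((x - 1.0) ** 2) + 0.01 * jnp.dot(noise, x)

    key = jax.random.PRNGKey(100)  # change the seed to re-run with fresh samples
    result = jax.scipy.optimize.minimize(
        lambda x: f(x, key), x0=jnp.zeros(3), method="BFGS"
    )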
If you're interested in stochastic optimization in JAX, you might try JAXOpt, which provides a much more flexible set of optimization routines, including stochastic solvers; a sketch follows.
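For instance, jaxopt.OptaxSolver runs an optax optimizer step by step and forwards extra arguments to the objective, so you can split a fresh subkey at every iteration. A hedged sketch, assuming jaxopt and optax are installed (the objective, learning rate, and step count are illustrative):

    import jax
    import jax.numpy as jnp
    import jaxopt
    import optax

    def f(x, key):
        # Illustrative noisy objective; a fresh key arrives at every step.
        noise = jax.random.normal(key, x.shape)
        return jnp.sum((x - 1.0) ** 2) + 0.01 * jnp.dot(noise, x)

    solver = jaxopt.OptaxSolver(fun=f, opt=optax.adam(1e-2))

    x = jnp.zeros(3)
    key = jax.random.PRNGKey(100)
    state = solver.init_state(x, key)
    for _ in range(500):
        key, subkey = jax.random.split(key)  # reproducible fresh randomness per step
        x, state = solver.update(x, state, subkey)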
Regarding your final question: if you'd like to print values during the course of a jit-compiled optimization or other loop, you can use jax.debug.print.
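For example, the following prints the argument each time the jitted objective is evaluated (the quadratic f is just a placeholder):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        jax.debug.print("f evaluated at x = {x}", x=x)  # works under jit
        return jnp.sum((x - 1.0) ** 2)

    result = jax.scipy.optimize.minimize(f, x0=jnp.zeros(3), method="BFGS")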