I'm trying to use the BFGS optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f that is estimated from pseudo-random samples and takes a jax.random.PRNGKey as an argument. To use this function with the jax/tfp BFGS minimizer, I wrap it inside a lambda function:

seed = 100
key  = jax.random.PRNGKey(seed)
fun = lambda x: f(x, key)
result = jax.scipy.optimize.minimize(fun = fun, ...)

What is the best way to update the key when the minimization routine calls the function to be minimized so that I use different pseudo-random numbers in a reproducible way? Maybe a global key variable? If yes, is there an example I could follow?

Secondly, is there a way to make the optimization stop after a certain amount of time, as one could do with a callback in scipy? I could directly use the scipy implementation of BFGS / L-BFGS-B / etc. and use jax only for the estimation of the function and of its gradients, which seems to work. Is there a difference between the scipy, jax.scipy and tfp.jax BFGS implementations?

Finally, is there a way to print the values of the arguments of fun during the bfgs optimization in jax.scipy or tfp, given that f is jitted?

Thank you!

Dan Leonte

1 Answer

There is no way to do what you're asking with jax.scipy.optimize.minimize, because the minimizer does not offer any means to track changing state between function calls, and does not provide for any inbuilt stochasticity in the optimizer.
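What the minimizer does support is a key that stays fixed for the whole run, because the objective is then an ordinary deterministic function of x. A minimal sketch of that case follows; the objective `f`, the sample size, and the starting point are made up for illustration and are not from the question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def f(x, key):
    # Hypothetical objective: Monte Carlo estimate of E[||x - z||^2]
    # with z ~ N(1, 1), using 1000 samples drawn from `key`.
    z = 1.0 + jax.random.normal(key, (1000,))
    return jnp.mean(jnp.sum((x[None, :] - z[:, None]) ** 2, axis=1))


key = jax.random.PRNGKey(100)

# The key is closed over and never changes, so every call to `fun`
# sees the same samples and `fun` is a pure function of x.
fun = lambda x: f(x, key)

x0 = jnp.zeros(2)
result = minimize(fun, x0, method="BFGS")
```

With the samples held fixed, `result.x` minimizes the sample average rather than the true expectation, so each coordinate ends up at the mean of the drawn samples.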

If you're interested in stochastic optimization in JAX, you might try the stochastic optimizers in JAXOpt, which provides a much more flexible set of optimization routines.

Regarding your last question, if you'd like to print values during the course of a jit-compiled optimization or other loop, you can use jax.debug.print.
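As a rough illustration, here is a jitted function that prints its argument each time it runs. The function body is a made-up quadratic, not the `f` from the question.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # jax.debug.print is evaluated when the compiled code runs, so it shows
    # concrete values. A plain Python print would only fire during tracing
    # and would show tracer objects instead.
    jax.debug.print("x = {x}", x=x)
    return jnp.sum((x - 1.0) ** 2)


value = f(jnp.array([0.0, 2.0]))
```

The same call can be placed inside the objective passed to an optimizer, in which case it fires on every evaluation of the objective.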

jakevdp
  • Thank you for your answer @jakevdp. What if the stochasticity comes from estimating the gradients with samples via the reparameterization trick? More precisely, $D$ is a function of $\theta$ in $$ \min_{x} \mathbb{E}_{D}[f(x, \theta, D)] $$ – Dan Leonte Jan 25 '23 at 17:06
  • I'm not sure what that means. – jakevdp Jan 25 '23 at 17:22
  • JAX's minimizer is defined assuming a stateless (pure) deterministic function `f` that takes a vector `x` and outputs a scalar. If you can express your function that way, then you can use `jax.scipy.optimize.minimize`. If not, you'll have to find another tool. – jakevdp Jan 25 '23 at 17:24
  • In my case, `f` is an expectation, which I estimate with samples, hence the stochasticity. I do not have batches of data to be loaded in an iterator, just one deterministic dataset. I think that the stochastic optimizers you recommended would do the trick, as in `rng_seq = hk.PRNGSequence(FLAGS.random_seed)` from [VAE](https://jaxopt.github.io/stable/auto_examples/deep_learning/haiku_vae.html) – Dan Leonte Jan 25 '23 at 21:16
  • Yet I was hoping to use a quasi-Newton method. These don't seem to be available in the stochastic optimization library. Do you think using the `next()` method of a `hk.PRNGSequence` inside my `f` would allow the use of `jax.scipy.optimize.minimize`, or would I still have to use the stochastic optimizers? Thank you so much for your time and answers! – Dan Leonte Jan 25 '23 at 21:22
  • No, I don't believe that will work, because it relies on an external iterator state. JAX's primitives would allow you to write a minimizer that lets `func` pass along an updated internal state, but `jax.scipy.optimize.minimize` is not written to allow that. The change you'd need to make is to add the PRNG key to the state here: https://github.com/google/jax/blob/78599e65d11a418bd81c7763a3d4fdab960c7bc0/jax/_src/scipy/optimize/bfgs.py#L153-L161, and then update it each time the function is called. – jakevdp Jan 25 '23 at 22:18
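The idea in the last comment, carrying the PRNG key in the optimizer's loop state and splitting it on every function evaluation, can be sketched with a hand-written loop. To stay short, this sketch uses plain gradient descent inside `jax.lax.scan` rather than BFGS, so it is not the change to `bfgs.py` itself. The objective `f`, the step size, and the helper names `step` and `run` are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x, key):
    # Hypothetical stochastic objective: estimate of E[||x - z||^2], z ~ N(1, I).
    z = 1.0 + jax.random.normal(key, (256,) + x.shape)
    return jnp.mean(jnp.sum((x - z) ** 2, axis=-1))


def step(state, _):
    x, key = state
    # Split the carried key: `subkey` drives this evaluation, and the new
    # `key` goes back into the state for the next iteration.
    key, subkey = jax.random.split(key)
    grad = jax.grad(f)(x, subkey)
    x = x - 0.1 * grad  # plain gradient step (assumed step size), not BFGS
    return (x, key), f(x, subkey)


def run(x0, key, num_steps=200):
    # The loop state is (x, key), so the key is updated without any global
    # variable or external iterator.
    (x, _), losses = jax.lax.scan(step, (x0, key), None, length=num_steps)
    return x, losses


x_a, losses_a = run(jnp.zeros(2), jax.random.PRNGKey(100))
x_b, losses_b = run(jnp.zeros(2), jax.random.PRNGKey(100))
```

Because every key is derived from the initial one, two runs with the same seed use the same sequence of samples, which gives different random numbers at each step while staying reproducible.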