I am playing with the mnist_vae example and can't figure out how to properly save and load the weights of the trained model.
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
opt_state = opt_init(init_params)
After that, I train the model using opt_update and want to save it. However, I haven't found any function for saving the optimizer state to disk.
I tried saving the parameters and initializing opt_state from them, but not all of the information is preserved: the resulting opt_state_1 is not the same as the original opt_state.
weights = get_params(opt_state)
jnp.save(file, weights)
weights = jnp.load(file, allow_pickle=True)
opt_state_1 = opt_init(weights)
How do I properly save the model I trained?
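For reference, here is a minimal sketch of one approach that might work: unpack the whole optimizer state into a picklable pytree with optimizers.unpack_optimizer_state, pickle it, and rebuild it with optimizers.pack_optimizer_state. Several things here are assumptions rather than part of the mnist_vae example: the import path jax.example_libraries.optimizers (older releases expose the same module as jax.experimental.optimizers), the toy parameters and gradients, step_size=0.1, and the temporary file path. Note that pickle should only be used with files you trust.

```python
import os
import pickle
import tempfile

import jax.numpy as jnp
from jax.example_libraries import optimizers  # assumed import path

# Toy parameters standing in for (init_encoder_params, init_decoder_params).
init_params = (jnp.ones((3, 2)), jnp.zeros((2,)))

# step_size=0.1 is an arbitrary value chosen for this sketch.
opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(init_params)

# One made-up update so the momentum buffers are no longer all zeros.
grads = (jnp.full((3, 2), 0.5), jnp.full((2,), -1.0))
opt_state = opt_update(0, grads, opt_state)

# Save: unpack the optimizer state into a picklable pytree and write it out.
path = os.path.join(tempfile.mkdtemp(), "opt_state.pkl")  # hypothetical location
with open(path, "wb") as f:
    pickle.dump(optimizers.unpack_optimizer_state(opt_state), f)

# Load: read the pytree back and repack it into an optimizer state.
with open(path, "rb") as f:
    restored_state = optimizers.pack_optimizer_state(pickle.load(f))
```

Unlike calling opt_init on the saved parameters, this keeps the momentum buffers as well as the parameters, so training can resume from restored_state with opt_update.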