
I am playing with the mnist_vae example and can't figure out how to properly save and load the weights of the trained model. I set up the model and the optimizer like this:

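```python
from jax import random
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older JAX releases

# encoder_init, decoder_init, batch_size and step_size come from the mnist_vae example
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
opt_state = opt_init(init_params)
```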

After that, I train the model using opt_update and want to save it. However, I haven't found any function for saving the optimizer state to disk.

I tried saving the parameters and re-initialising opt_state from them, but not all the information is preserved: the resulting opt_state_1 is not the same as the original opt_state.

```python
weights = get_params(opt_state)
jnp.save(file, weights)
weights = jnp.load(file, allow_pickle=True)
opt_state_1 = opt_init(weights)  # rebuild the optimizer state from the saved parameters
```
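Why opt_state_1 differs: optimizers.momentum keeps a velocity buffer next to every parameter, and its opt_init always creates that buffer as zeros, so a state rebuilt from get_params has the right parameters but none of the accumulated momentum. A minimal sketch of this, assuming a hypothetical toy params dict and quadratic loss in place of the VAE:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older JAX releases

# Hypothetical toy parameters and loss standing in for the VAE.
params = {"w": jnp.ones(3), "b": jnp.zeros(())}

def loss(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(params)
for i in range(5):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(i, grads, opt_state)

# Rebuild a state from the parameters alone, as in the question.
opt_state_1 = opt_init(get_params(opt_state))

def velocities(state):
    # unpack_optimizer_state gives one JoinPoint per parameter leaf;
    # for momentum its subtree is the (param, velocity) pair.
    unpacked = optimizers.unpack_optimizer_state(state)
    return [jp.subtree[1] for jp in jax.tree_util.tree_leaves(unpacked)]

trained_v = velocities(opt_state)
rebuilt_v = velocities(opt_state_1)
print([bool(jnp.any(v != 0)) for v in trained_v])  # [False, True]: "b" never moves, "w" has momentum
print([bool(jnp.any(v != 0)) for v in rebuilt_v])  # [False, False]: momentum reset to zero
```

The parameters of the two states are identical; only the velocities differ, which is why continuing training from opt_state_1 does not follow the same trajectory.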

How do I properly save the model I trained?

egorssed

1 Answer

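Convert the whole optimizer state, not just the parameters, into a picklable structure with optimizers.unpack_optimizer_state, pickle that, and rebuild the state after loading with optimizers.pack_optimizer_state:

```python
import os
import pickle
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older JAX releases

ckpt_file = os.path.join(config["ckpt_path"], "best_ckpt.pkl")

# unpack_optimizer_state keeps the momentum buffers, not just the parameters
trained_state = optimizers.unpack_optimizer_state(opt_state)
with open(ckpt_file, "wb") as f:
    pickle.dump(trained_state, f)

with open(ckpt_file, "rb") as f:
    best_state = pickle.load(f)
best_opt_state = optimizers.pack_optimizer_state(best_state)  # usable with opt_update
```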
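A self-contained sketch of this round trip, assuming a hypothetical toy params dict and loss instead of the VAE, and a temporary directory instead of config["ckpt_path"]. It checks that the restored state holds the same arrays (parameters and velocities) as the original, and that one more opt_update step from either state yields the same parameters.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

params = {"w": jnp.ones(3), "b": jnp.zeros(())}  # hypothetical toy params

def loss(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(params)
for i in range(5):
    opt_state = opt_update(i, jax.grad(loss)(get_params(opt_state)), opt_state)

# Save the unpacked state (a structure of JoinPoint objects) with pickle.
ckpt_file = os.path.join(tempfile.mkdtemp(), "best_ckpt.pkl")
with open(ckpt_file, "wb") as f:
    pickle.dump(optimizers.unpack_optimizer_state(opt_state), f)

# Load it and turn it back into an OptimizerState.
with open(ckpt_file, "rb") as f:
    restored_state = optimizers.pack_optimizer_state(pickle.load(f))

# Both states hold identical arrays (params and velocities) ...
orig_leaves = jax.tree_util.tree_leaves(opt_state)
restored_leaves = jax.tree_util.tree_leaves(restored_state)
same = all(bool(jnp.array_equal(x, y)) for x, y in zip(orig_leaves, restored_leaves))

# ... so one more update from either state gives the same parameters.
next_orig = get_params(opt_update(5, jax.grad(loss)(get_params(opt_state)), opt_state))
next_restored = get_params(opt_update(5, jax.grad(loss)(get_params(restored_state)), restored_state))
```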
While this code may solve the question, [including an explanation](//meta.stackexchange.com/q/114762) of how and why this solves the problem would really help to improve the quality of your post, and probably result in more up-votes. Remember that you are answering the question for readers in the future, not just the person asking now. Please edit your answer to add explanations and give an indication of what limitations and assumptions apply. – Adrian Mole Mar 15 '21 at 12:35
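One limitation of the pickle approach: unpickling data from an untrusted source can execute arbitrary code. A sketch of a pickle-free alternative, assuming the loading code can rebuild the same parameter structure with the same optimizer (again a hypothetical toy params dict): OptimizerState is registered as a JAX pytree, so its arrays can be written with numpy.savez and poured back, in order, into the structure of a freshly initialised state.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.example_libraries import optimizers

params = {"w": jnp.ones(3), "b": jnp.zeros(())}  # hypothetical toy params

def loss(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(params)
for i in range(5):
    opt_state = opt_update(i, jax.grad(loss)(get_params(opt_state)), opt_state)

# Save only the arrays (params and velocities); savez names them arr_0, arr_1, ...
leaves = jax.tree_util.tree_leaves(opt_state)
ckpt_file = os.path.join(tempfile.mkdtemp(), "best_ckpt.npz")
np.savez(ckpt_file, *[np.asarray(x) for x in leaves])

# Restore: take the tree structure from a fresh state with the same layout
# and fill it with the saved arrays in their original order.
template = opt_init(params)
with np.load(ckpt_file) as data:
    loaded = [jnp.asarray(data[f"arr_{k}"]) for k in range(len(data.files))]
restored_state = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(template), loaded)
```

The design trade-off is that the file stores no structure of its own: it only restores correctly if the template is built from the same parameter layout and optimizer, otherwise the arrays will not match the expected structure.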