微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何保存JAX训练模型的优化器状态?

如何解决如何保存JAX训练模型的优化器状态?

我正在玩mnist_vae示例,但无法弄清楚如何正确保存/加载训练模型的权重。

enc_init_rng,dec_init_rng = random.split(random.PRNGKey(2))
_,init_encoder_params = encoder_init(enc_init_rng,(batch_size,28 * 28))
_,init_decoder_params = decoder_init(dec_init_rng,10))
init_params = init_encoder_params,init_decoder_params

opt_init,opt_update,get_params = optimizers.momentum(step_size,mass=0.9)
opt_state = opt_init(init_params)

之后,我使用opt_update训练模型并希望保存它。但是,我没有找到任何将优化器状态保存到磁盘的功能

我尝试保存参数并使用它们初始化opt_state,但并非所有信息都保存下来,结果opt_state_1不是原始的opt_state。

weights=get_params(opt_state)  
jnp.save(file,weights)  
weights = jnp.load(file,allow_pickle=True)  
opt_state_1 = opt_init(init_params)

如何正确保存我训练的模型?

解决方法

import pickle
from jax.experimental import optimizers

trained_params = optimizers.unpack_optimizer_state(opt_state)
pickle.dump(trained_params,open(os.path.join(config["ckpt_path"],"best_ckpt.pkl"),"wb"))

best_params = pickle.load(open(os.path.join(config["ckpt_path"],"rb"))
best_opt_state = optimizers.pack_optimizer_state(best_params)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。