tensorflow保存模型、恢复模型

1、模型训练(部分代码):

X = tf.placeholder(tf.float64,X_data.shape,name='X')
Y = tf.placeholder(tf.float64,Y_data.shape,name='Y')
epoch_num = 500
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_data = []
    # 创建FileWriter对象,用当前计算图初始化
    writer = tf.summary.FileWriter('./summary/', sess.graph)

    # 保存模型
    saver_path = './model/checkpoint/model.ckpt' # 模型保存路径
    saver = tf.train.Saver() # 新建Saver()对象

    for i in range(1,epoch_num+1):
        _, loss = sess.run([optimizer,loss_func],feed_dict={X:X_data,Y:Y_data})
        loss_data.append(loss)
        saved_path = saver.save(sess, saver_path) # 保存模型
        print("epoch:%d,loss:%.4g" % (i,loss))
    # 关闭FileWriter
    writer.close()

 

2、保存模型

# 模型保存路径
saver_path = './model/checkpoint/model.ckpt' 
# 新建Saver()对象
saver = tf.train.Saver()
# 保存模型
saved_path = saver.save(sess, saver_path)

执行之后,在目录./model/checkpoint/model.ckpt下,生成模型相关文件,如图:

 

3、恢复模型并使用模型、变量

meta_path = './model/checkpoint/model.ckpt.meta'
model_path = './model/checkpoint/model.ckpt'
# 导入计算图
saver = tf.train.import_meta_graph(meta_path)
config = tf.ConfigProto()
with tf.Session(config=config) as sess:
    # 恢复模型
    saver.restore(sess, model_path)
    # 此时默认图就是导入的图
    graph_restore = tf.get_default_graph()
    # 恢复变量
    W = graph_restore.get_tensor_by_name('W:0')
    b = graph_restore.get_tensor_by_name('b:0')
    # 预测模型
    predict_func = tf.matmul(test_data, W)
    predict_value = sess.run([predict_func],feed_dict={x:test_data})

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐