1、模型训练(部分代码):
X = tf.placeholder(tf.float64,X_data.shape,name='X') Y = tf.placeholder(tf.float64,Y_data.shape,name='Y') epoch_num = 500 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) loss_data = [] # 创建FileWriter对象,用当前计算图初始化 writer = tf.summary.FileWriter('./summary/', sess.graph) # 保存模型 saver_path = './model/checkpoint/model.ckpt' # 模型保存路径 saver = tf.train.Saver() # 新建Saver()对象 for i in range(1,epoch_num+1): _, loss = sess.run([optimizer,loss_func],feed_dict={X:X_data,Y:Y_data}) loss_data.append(loss) saved_path = saver.save(sess, saver_path) # 保存模型 print("epoch:%d,loss:%.4g" % (i,loss)) # 关闭FileWriter writer.close()
2、保存模型
# 模型保存路径 saver_path = './model/checkpoint/model.ckpt' # 新建Saver()对象 saver = tf.train.Saver() # 保存模型 saved_path = saver.save(sess, saver_path)
执行之后,在目录./model/checkpoint/model.ckpt下,生成模型相关文件,如图:
3、恢复模型并使用模型、变量
meta_path = './model/checkpoint/model.ckpt.meta' model_path = './model/checkpoint/model.ckpt' # 导入计算图 saver = tf.train.import_meta_graph(meta_path) config = tf.ConfigProto() with tf.Session(config=config) as sess: # 恢复模型 saver.restore(sess, model_path) # 此时默认图就是导入的图 graph_restore = tf.get_default_graph() # 恢复变量 W = graph_restore.get_tensor_by_name('W:0') b = graph_restore.get_tensor_by_name('b:0') # 预测模型 predict_func = tf.matmul(test_data, W) predict_value = sess.run([predict_func],feed_dict={x:test_data})
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。