微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow:如何加载/恢复完整模型.pb、.pbtxt、.graph、.ckpt、变量、事件?

如何解决Tensorflow:如何加载/恢复完整模型.pb、.pbtxt、.graph、.ckpt、变量、事件?

我最近从 PyTorch 切换到 TF(1 和 2),我正在尝试使用它获得一个良好的工作流程。

我想做的简单事情如下:

  1. TF1 zooTF2 zoo 加载完整的预训练对象检测模型
  2. 使用 model.summary() 检查加载模型的网络架构。
  3. 使用预训练的加载模型进行推理。
  4. 修改(例如重塑、删除添加)加载模型的层和权重。
  5. 重新训练修改后的加载模型。

我知道 TF 有图和权重的概念,而不是 PyTorch,后者只有包含一切的模型。 尽管如此,我还是找不到加载预训练模型的简单且最佳的方法,而且互联网上充斥着针对不同 tf 版本的不同答案。

我真的很困惑,因为当我从 TF1 zoo(或 TF2 zoo)下载预训练模型时,为了实现上述几点,我有很多不同的文件可用。

this one为例,TF1动物园列表中的第一个我有 saved_model 文件夹,其中包含 saved_model.pbvariables(空)文件夹,frozen_inference_graph.pb model.ckpt 文件pipeline.config 和某些情况下是 event 文件。 所有这些不同的文件真的需要对图形结构和权重进行编码吗?我是否遗漏了什么或者这只是比必要的更复杂? 此外,如果您从 TF2 zoo 下载模型,文件/文件夹结构会有所不同(见下图)

enter image description here

我的尝试

import tensorflow as tf #(v2.4)
def load_pretrained_model(self,saved_model_sub_folder,mode):
    # 1. this only load an AutoTrackable object that can be use for inference but no graph
    if mode == '.pb':
        model_dir = str(TRAINED_MODEL_DIR) + saved_model_sub_folder
        model_dir = pathlib.Path(model_dir) / "saved_model"
        model = tf.saved_model.load(str(model_dir),None,'.')
        detection_model = model.signatures['serving_default']

    # 2. this returns None
    elif mode == '.graph':
        def load_graph(frozen_graph_filename):
            with tf.compat.v1.gfile.GFile(frozen_graph_filename,"rb") as f:
                graph_def = tf.compat.v1.GraphDef()
                graph_def.ParseFromString(f.read())
                return graph_def
        detection_model = tf.compat.v1.import_graph_def(load_graph(frozen_graph_filename))
    else:
        detection_model = None

    return detection_model

Tl;博士

有人可以回答上面关于如何在 python3 中加载完整(图形、权重、一切..)可定制 tensorflow1 或 tensorflow2 模型的一些要点(1 到 5)?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。