微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何将TensorRT SavedModel加载到TensorFlow Estimator?

如何解决如何将TensorRT SavedModel加载到TensorFlow Estimator?

我正在使用TensorFlow 1.14,并将TensorFlow SavedModel加载到Estimator中,以下代码我有用:

estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)
prediction_results = estimator.predict(input_fn)

但是,当我使用TensorRT将TensorFlow SavedModel转换为TensorRT SavedModel时,它返回一条错误消息:

ValueError: Directory provided has an invalid SavedModel format: saved_models/nvidia_fp16_converted

我已经对其进行了进一步的研究,问题似乎在于TensorRT在SavedModel目录中没有生成任何变量信息(包括variables.index),这导致了上述错误。有人知道如何解决这个问题吗?

解决方法

对于任何有兴趣的人,以下是我自己想出的解决方案: 通常,可以使用以下命令将TF SavedModel加载到Estimator:

estimator = tf.contrib.estimator.SavedModelEstimator(SAVED_MODEL_DIR)

但是,由于TensorRT将所有变量都转换为常量,因此在加载TensorRT时出现SavedModel错误,因此SavedModel目录中没有变量的信息(例如,没有variables.index)→由于Estimator尝试加载变量,因此不会发生错误。解决问题的步骤:

  • 我们需要转到文件"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py",line 2330,in _get_saved_model_ckpt并注释掉对variable.index的检查
"""if not gfile.Exists(os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),compat.as_text('variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)"""
  • 转到文件"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py",line 145,in __init__checkpoint_utils.list_variables(checkpoint)]并进行更改,以使Estimator不会尝试从SavedModel加载变量:
"""checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
vars_to_warm_start = [name for name,_ in
checkpoint_utils.list_variables(checkpoint)]
warm_start_settings = estimator_lib.WarmStartSettings(
ckpt_to_initialize_from=checkpoint,vars_to_warm_start=vars_to_warm_start)"""
warm_start_settings = estimator_lib.WarmStartSettings(ckpt_to_initialize_from = estimator_lib._get_saved_model_ckpt(saved_model_dir))
  • 转到文件:"/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/saved_model_estimator.py",line 256,in _model_fn_from_saved_modeltraining_util.assert_global_step(global_step_tensor),并注释掉“ global_step”的检查内容,以防模型是从NVIDIA示例生成的(因此无需进行任何培训并且未设置“ global_step”):
#global_step_tensor = training_util.get_global_step(g)
#training_util.assert_global_step(global_step_tensor)
  • 转到文件:"/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py",line 291,in init_from_checkpoint init_from_checkpoint_fn),然后将return放在init_from_checkpoint函数的开头,这样它就不会尝试加载检查点:
def _init_from_checkpoint(ckpt_dir_or_file,assignment_map):
"""See `init_from_checkpoint` for documentation."""
return

在完成所有上述更改之后,加载过程应该会很好。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。