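How to solve: TensorFlow 2 Object Detection API - converting a .ckpt file to .pb or any saved model format for simple web app deployment
I am trying to run object detection with the TensorFlow Object Detection API, using the ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8 model from the TensorFlow model zoo. I can detect objects in a single test image using the checkpoint (saved as ckpt-17.data-00000-of-00001).
I need to convert this checkpoint to a saved model file (.pb/.h5) for use in a simple Flask web application.
I am having trouble converting the ckpt file to .pb with the following code: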
print("""python {}/research/object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path={}/{}/pipeline.config \
--trained /content/RealTimeObjectDetectionStages/Tensorflow/workspace/models/my_ssd_mobnet/ckpt-17.index \
--output_directory /content/RealTimeObjectDetectionStages/Tensorflow/workspace/models""".format(APIMODEL_PATH,MODEL_PATH,CUSTOM_MODEL_NAME,CUSTOM_MODEL_NAME))
!python /content/RealTimeObjectDetectionStages/Tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path=/content/RealTimeObjectDetectionStages/Tensorflow/workspace/models/my_ssd_mobnet/pipeline.config --trained /content/RealTimeObjectDetectionStages/Tensorflow/workspace/models/my_ssd_mobnet/ckpt-17.ckpt --output_directory /content/RealTimeObjectDetectionStages/Tensorflow/workspace/models
When I try to convert to .pb, I get the following error:
2021-03-12 11:25:57.212865: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/absl/flags/_flagvalues.py", line 541, in _assert_validators
    validator.verify(self)
  File "/usr/local/lib/python3.7/dist-packages/absl/flags/_validators.py", line 82, in verify
    raise _exceptions.ValidationError(self.message)
absl.flags._exceptions.ValidationError: Flag --trained_checkpoint_prefix must have a value other than None.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/content/RealTimeObjectDetectionStages/Tensorflow/models/research/object_detection/export_inference_graph.py", line 206, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 294, in run
    flags_parser,
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 363, in _run_init
    flags_parser=flags_parser,
  line 213, in _register_and_parse_flags_with_usage
    args_to_main = flags_parser(original_argv)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/platform/app.py", line 31, in _parse_flags_tolerate_undef
    return flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/platform/flags.py", line 113, in __call__
    return self.__dict__['__wrapped'].__call__(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/absl/flags/_flagvalues.py", line 649, in __call__
    self.validate_all_flags()
  File "/usr/local/lib/python3.7/dist-packages/absl/flags/_flagvalues.py", line 523, in validate_all_flags
    self._assert_validators(all_validators)
  File "/usr/local/lib/python3.7/dist-packages/absl/flags/_flagvalues.py", line 544, in _assert_validators
    raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e)))
absl.flags._exceptions.IllegalFlagValueError: flag --trained_checkpoint_prefix=None: Flag --trained_checkpoint_prefix must have a value other than None.
How do I save the checkpoint for deployment?
I am using TensorFlow 2 on Google Colab.
Solution
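The last line of the traceback suggests the cause: --trained_checkpoint_prefix was never set, because the command passes --trained, a flag the script does not appear to define. The expected value is also a checkpoint prefix such as .../ckpt-17, not one of the files on disk (ckpt-17.index or ckpt-17.data-00000-of-00001). Below is a minimal sketch of deriving the prefix from a file name. The helper name checkpoint_prefix is hypothetical and not part of the Object Detection API.

```python
import re

# Hypothetical helper (not part of the Object Detection API): strips the
# on-disk suffixes TensorFlow adds to a checkpoint so only the prefix remains.
_SUFFIX = re.compile(r"\.(index|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path: str) -> str:
    """Return the checkpoint prefix for an .index or .data-* file path."""
    return _SUFFIX.sub("", path)


print(checkpoint_prefix("my_ssd_mobnet/ckpt-17.index"))
print(checkpoint_prefix("my_ssd_mobnet/ckpt-17.data-00000-of-00001"))
```

Both calls print my_ssd_mobnet/ckpt-17, which is the form the flag expects.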
Correcting the file path format fixed the problem:
!python /content/RealTimeObjectDetectionStages/Tensorflow/models/research/object_detection/model_main_tf2.py --model_dir=/content/RealTimeObjectDetectionStages/Tensorflow/workspace/models/my_ssd_mobnet --pipeline_config_path=/content/RealTimeObjectDetectionStages/Tensorflow/workspace/models/my_ssd_mobnet/pipeline.config --num_train_steps=1000
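Note that model_main_tf2.py is the training and evaluation script, so the command above writes checkpoints and does not export a model. In the TF2 Object Detection API, exporting to SavedModel format is normally done with exporter_main_v2.py, which takes a checkpoint directory through --trained_checkpoint_dir instead of a prefix. The sketch below assembles such a command, assuming the Colab paths from the question. The helper name build_export_command and the export output subdirectory are illustrative assumptions.

```python
import shlex


def build_export_command(api_path: str, model_dir: str, output_dir: str) -> str:
    """Assemble an exporter_main_v2.py command line as a single string."""
    args = [
        "python",
        f"{api_path}/research/object_detection/exporter_main_v2.py",
        "--input_type=image_tensor",
        f"--pipeline_config_path={model_dir}/pipeline.config",
        # TF2 exporter reads the latest checkpoint found in this directory.
        f"--trained_checkpoint_dir={model_dir}",
        f"--output_directory={output_dir}",
    ]
    # Quote each argument so paths containing spaces survive a shell.
    return " ".join(shlex.quote(a) for a in args)


# Paths taken from the question; the "export" subdirectory is an assumption.
BASE = "/content/RealTimeObjectDetectionStages/Tensorflow"
MODEL_DIR = f"{BASE}/workspace/models/my_ssd_mobnet"
print(build_export_command(f"{BASE}/models", MODEL_DIR, f"{MODEL_DIR}/export"))
```

In Colab the printed line can be run with a leading !. If the export succeeds, the output directory should contain a saved_model folder with saved_model.pb, which a Flask app can load with tf.saved_model.load.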