如何解决从检查点创建Estimator并另存为SavedModel,无需进一步培训
我已经从TF Slim resnet V2检查点创建了一个估算器,并对其进行了测试以进行预测。我所做的主要工作基本上与普通的Estimator相似,并且与Assign_from_checkpoint_fn一样:
def model_fn(features,labels,mode,params):
...
slim.assign_from_checkpoint_fn(os.path.join(checkpoint_dir,'resnet_v2_50.ckpt'),slim.get_model_variables('resnet_v2')
...
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = {
'class_ids': predicted_classes[:,tf.newaxis],'probabilities': tf.nn.softmax(logits),'logits': logits,}
return tf.estimator.EstimatorSpec(mode,predictions=predictions)
要将估算器导出为SavedModel,我进行了以下serving_input_fn:
def image_preprocess(image_buffer):
image = tf.image.decode_jpeg(image_buffer,channels=3)
image_preprocessing_fn = preprocessing_factory.get_preprocessing('inception',is_training=False)
image = image_preprocessing_fn(image,FLAGS.image_size,FLAGS.image_size)
return image
def serving_input_fn():
input_ph = tf.placeholder(tf.string,shape=[None],name='image_binary')
image_tensors = image_preprocess(input_ph)
return tf.estimator.export.ServingInputReceiver(image_tensors,input_ph)
在主函数中,我使用export_saved_model尝试将Estimator导出为SavedModel格式:
def main():
...
classifier = tf.estimator.Estimator(model_fn=model_fn)
classifier.export_saved_model(dir_path,serving_input_fn)
但是,当我尝试运行代码时,它显示“在/ tmp / tmpn3spty2z找不到经过训练的模型”。据我了解,这个export_saved_model试图找到训练有素的Estimator模型以导出到SavedModel。但是,我想知道是否可以通过任何方法将经过预训练的检查点还原到Estimator中,并将Estimator导出到SavedModel而不进行任何进一步的培训?
解决方法
我已经解决了我的问题。要将带有TF 1.14的TF Slim Resnet检查点导出到SavedModel,可以将热启动与export_savedmodel一起使用,如下所示:
config = tf.estimator.RunConfig(save_summary_steps = None,save_checkpoints_secs = None)
warm_start = tf.estimator.WarmStartSettings(checkpoint_dir,checkpoint_name)
classifier = tf.estimator.Estimator(model_fn=model_fn,warm_start_from = warm_start,config = config)
classifier.export_savedmodel(export_dir_base = FLAGS.output_dir,serving_input_receiver_fn = serving_input_fn)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。