微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

从检查点创建Estimator并另存为SavedModel,无需进一步培训

如何解决从检查点创建Estimator并另存为SavedModel,无需进一步培训

我已经从TF Slim resnet V2检查点创建了一个估算器,并对其进行了测试以进行预测。我所做的主要工作基本上与普通的Estimator相似,并且与Assign_from_checkpoint_fn一样:

def model_fn(features,labels,mode,params):
  ...
  slim.assign_from_checkpoint_fn(os.path.join(checkpoint_dir,'resnet_v2_50.ckpt'),slim.get_model_variables('resnet_v2')
  ...
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
      'class_ids': predicted_classes[:,tf.newaxis],'probabilities': tf.nn.softmax(logits),'logits': logits,}
  return tf.estimator.EstimatorSpec(mode,predictions=predictions)

要将估算器导出为SavedModel,我进行了以下serving_input_fn:

def image_preprocess(image_buffer):
    image = tf.image.decode_jpeg(image_buffer,channels=3)
    image_preprocessing_fn = preprocessing_factory.get_preprocessing('inception',is_training=False)
    image = image_preprocessing_fn(image,FLAGS.image_size,FLAGS.image_size)
    return image

def serving_input_fn():
    input_ph = tf.placeholder(tf.string,shape=[None],name='image_binary')
    image_tensors = image_preprocess(input_ph)
    return tf.estimator.export.ServingInputReceiver(image_tensors,input_ph)

在主函数中,我使用export_saved_model尝试将Estimator导出为SavedModel格式:

def main():
    ...
    classifier = tf.estimator.Estimator(model_fn=model_fn)
    classifier.export_saved_model(dir_path,serving_input_fn)

但是,当我尝试运行代码时,它显示“在/ tmp / tmpn3spty2z找不到经过训练的模型”。据我了解,这个export_saved_model试图找到训练有素的Estimator模型以导出到SavedModel。但是,我想知道是否可以通过任何方法将经过预训练的检查点还原到Estimator中,并将Estimator导出到SavedModel而不进行任何进一步的培训?

解决方法

我已经解决了我的问题。要将带有TF 1.14的TF Slim Resnet检查点导出到SavedModel,可以将热启动与export_savedmodel一起使用,如下所示:

config = tf.estimator.RunConfig(save_summary_steps = None,save_checkpoints_secs = None)
warm_start = tf.estimator.WarmStartSettings(checkpoint_dir,checkpoint_name)
classifier = tf.estimator.Estimator(model_fn=model_fn,warm_start_from = warm_start,config = config)
classifier.export_savedmodel(export_dir_base = FLAGS.output_dir,serving_input_receiver_fn =  serving_input_fn)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?