微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在GCP AI平台上使用TFRecord文件进行批量预测?

如何解决如何在GCP AI平台上使用TFRecord文件进行批量预测?

TL; DR 做批量预测时,Google Cloud AI平台如何解压缩TFRecord文件

我已经将训练有素的Keras模型部署到Google Cloud AI平台上,但是在批量预测的文件格式方面遇到了麻烦。为了进行培训,我使用tf.data.TFRecordDataset来读取TFRecord的列表,如下所示,都可以正常工作。

def unpack_tfrecord(record):
    parsed = tf.io.parse_example(record,{
        'chunk': tf.io.FixedLenFeature([128,2,3],tf.float32),# Input
        'class': tf.io.FixedLenFeature([2],tf.int64),# One-hot classification (binary)
    })

    return (parsed['chunk'],parsed['class'])

files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset,epochs=train_epochs)
tf.saved_model.save(model,model_save_path)

我将保存的模型上传到Cloud Storage并在AI Platform中创建一个新模型。 AI平台文档指出“带有gcloud工具的批处理[支持]带有JSON实例字符串的文本文件或TFRecord文件(可以压缩)”(https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。但是当我提供TFRecord文件时,出现错误

("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte",8)

我的TFRecord文件包含一堆Protobuf编码的tf.train.Example。我没有为AI平台提供unpack_tfrecord函数,所以我认为它无法正确解压缩是有道理的,但是我知道节点的位置。我对使用JSON格式不感兴趣,因为数据太大。

解决方法

我不知道这是否是解决此问题的最佳方法,但对于TF 2.x,您可以执行以下操作:

import tensorflow as tf

def make_serving_input_fn():
    # your feature spec
    feature_spec = {
        'chunk': tf.io.FixedLenFeature([128,2,3],tf.float32),'class': tf.io.FixedLenFeature([2],tf.int64),}

    serialized_tf_examples = tf.keras.Input(
        shape=[],name='input_example_tensor',dtype=tf.string)

    examples = tf.io.parse_example(serialized_tf_examples,feature_spec)

    # any processing 
    processed_chunks = tf.map_fn(
        <PROCESSING_FN>,examples['chunk'],# ?
        dtype=tf.float32)

    return tf.estimator.export.ServingInputReceiver(
        features={<MODEL_FIRST_LAYER_NAME>: processed_chunks},receiver_tensors={"input_example_tensor": serialized_tf_examples}
    )


estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model,model_dir=<ESTIMATOR_SAVE_DIR>)

estimator.export_saved_model(
    export_dir_base=<WORKING_DIR>,serving_input_receiver_fn=make_serving_input_fn)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。