微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

来自 TF Hub 临时文件的微调 BERT SavedModel 硬引用 vocab.txt

如何解决来自 TF Hub 临时文件的微调 BERT SavedModel 硬引用 vocab.txt

首先,我对 TensorFlow 比较陌生,所以可能我做错了什么。我正在尝试为二进制文本分类创建一个独立的 TensorFlow 模型。我想创建一个可以通过 TensorFlow ServingAWS SageMaker 加载的 SavedModel。独立我的意思是我希望它是端到端的:将文本提供给端点,得到一个浮点数作为响应。

我已经开始在我的机器上本地使用 Classify text with BERT notebook,并将 electra_small 模型作为微调的基本模型。我已经在本地保存了微调的模型,将它加载回来并确保它可以工作。我关闭了我的电脑,第二天我发现我不能再加载我的模型了。原来它正在寻找一个临时的 tfhub 文件,该文件不再存在,导致 /tmp/tfhub_modules/09bd4e665682e6f03bc72fbcff7a68bf6879910e/assets/vocab.txt; No such file or directory 错误。我已经从 TF hub 重新下载了这些模型,模型开始加载并且工作正常。我删除/tmp/tfhub_modules/ 目录,模型再次停止加载。我检查了我保存微调模型的目录,它有与 TF Hub 临时文件相同内容vocab.txt 文件,但由于某种原因拒绝使用它。我希望 SavedModel 是自包含的,并使用它自己的本地词汇文件,而不是从 TF Hub 的外部临时目录中硬引用词汇文件

我可以在网上找到的大多数教程都在进行 word_ids + masks -> model 类型的分类,使 text -> word_ids + masks 成为预处理步骤的一部分。我真的不喜欢这个想法,因为我想让 API 消费者完全忘记任何类型的 NLP 相关的东西。我想保留它“文本输入,概率输出”。如果可能的话,我宁愿不为 TensorFlow Serving API 编写包装器来处理预处理。我希望找到并修改 BERT EN uncased preprocess 的源代码,但没有找到。 IE。我在 TensorFlow 的 GitHub 上找不到 bert_pack_inputs 的定义,这是一个有点独特的函数名称My issue

代码示例:

import tensorflow_hub as hub
import tensorflow_text
import tensorflow as tf


def build_classifier_model():
    tfhub_handle_encoder = 'https://tfhub.dev/google/electra_small/2'
    tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2'

    text_input = tf.keras.layers.Input(shape=(),dtype=tf.string,name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess,name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder,trainable=True,name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    net = tf.keras.layers.Dense(1,activation=None,name='classifier')(net)
    return tf.keras.Model(text_input,net)


model_save_path = './modelname'

# Run this one first,it will create a saved model in your working dir.

# if __name__ == '__main__':
#     classifier_model = build_classifier_model()
#     bert_raw_result = classifier_model(tf.constant(['some dummy text']))
#     print(bert_raw_result)
#     classifier_model.save(model_save_path,include_optimizer=False)


# Run this one next,this should work.

if __name__ == '__main__':
    classifier_model = tf.saved_model.load(model_save_path)
    bert_raw_result = classifier_model(tf.constant(['some dummy text']))
    print(bert_raw_result)

# Now,try to remove / rename `/tmp/tfhub_modules/` directory and run the code above again. It fails for me with exception I provided in OP.
# So there's a hard reference to `/tmp/tfhub_modules/09bd4e665682e6f03bc72fbcff7a68bf6879910e/assets/vocab.txt` file instead of the one in "modelname" saved model.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。