微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Torch JIT Trace = TracerWarning:将张量转换为 Python 布尔值可能会导致跟踪不正确

如何解决Torch JIT Trace = TracerWarning:将张量转换为 Python 布尔值可能会导致跟踪不正确

我正在学习本教程:https://huggingface.co/transformers/torchscript.html 创建我的自定义 BERT 模型的跟踪,但是在运行完全相同的 dummy_input 时,我收到一个错误

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. 
We cant record the data flow of Python values,so this value will be treated as a constant in the future. 

在我的模型和标记器中加载后,创建跟踪的代码如下:

text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0,1,1]

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor,segments_tensors]

traced_model = torch.jit.trace(model,dummy_input)

dummy_input 是张量列表,所以我不确定 Boolean 类型在这里发挥作用。有没有人明白为什么会发生这个错误以及是否发生了布尔转换?

非常感谢

解决方法

这个错误意味着什么

当您尝试torch.jit.trace具有数据相关控制流的模型时,会发生此警告

这个简单的例子应该会更清楚:

import torch


class Foo(torch.nn.Module):
    def forward(self,tensor):
        # It is data dependent
        # Trace will only work with one path
        if tensor.max() > 0.5:
            return tensor ** 2
        return tensor


model = Foo()
traced = torch.jit.script(model) # No warnings
traced = torch.jit.trace(model,torch.randn(10)) # Warning

本质上,BERT 模型有一些依赖于数据的控制流(如 iffor 循环),因此您会收到警告。

警告自身

您可以看到 BERT forward 代码 here

如果:

  • 参数不会改变(比如传递给 Noneforward 值),它会在 script 之后保持这种状态(例如在推理调用期间)
  • 如果有基于 __init__ 内部收集的数据的控制流(如配置),因为这不会改变

例如:

elif input_ids is not None:
    input_shape = input_ids.size()
    batch_size,seq_length = input_shape

只会作为一个带有 torch.jit.trace 的分支运行,因为它只是跟踪张量上的操作,并且不知道像这样的控制流。

HuggingFace 团队可能已经意识到这一点,并且此警告不是问题(尽管您可能会仔细检查您的用例或尝试使用 torch.jit.script

跟着torch.jit.script

这会很困难,因为整个模型必须torchscript兼容(torchscript 有 Python 的一个子集可用,而且很可能无法使用 BERT 开箱即用)。>

仅在必要时才这样做(可能不是)。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。