导出至iOS时,Python和LibTorch C ++之间的输出不一致

如何解决导出至iOS时,Python和LibTorch C ++之间的输出不一致

我已经为我的数据训练了HuggingFace RoBERTa模型(这是一种非常特殊的用法-因此是小模型/词汇!),并在Python上成功进行了测试。我将跟踪模型导出到iOS的LibTorch,但是设备上的预测结果与Python中的预测结果不匹配(给出了不同的argmax令牌索引)。我的转换脚本:

# torch = 1.5.0
# transformers = 3.2.0

config = RobertaConfig(
    vocab_size=858,max_position_embeddings=258,num_attention_heads=6,num_hidden_layers=4,type_vocab_size=1,torchscript=True,)

model = RobertaForMaskedLM(config=config).from_pretrained('./trained_RoBERTa')
model.cpu()
model.eval()

example_input = torch.LongTensor(1,256).random_(0,857).cpu()
traced_model = torch.jit.trace(model,example_input)
traced_model.save('./exports/trained_RoBERTa.pt')

过去,我遇到了另一个问题,该问题是我在Python + GPU中训练并转换为iOS的LibTorch的(视觉)模型,通过在我的map_location={'cuda:0': 'cpu'}调用添加torch.load()解决转换脚本。因此,我想知道是否:1)在这种情况下可以作为一种合理的解释?,以及2)在使用map_location语法加载时如何添加.from_pretrained()选项?

以防我对预测结果的Obj-C ++处理受到指责,以下是在设备上运行的Obj-C ++代码

- (NSArray<NSArray<NSNumber*>*>*)predictText:(NSArray<NSNumber*>*)tokenIDs {
    try {
        long count = tokenIDs.count;
        long* buffer = new long[count];
        for(int i=0; i < count;  i++) {
            buffer[i] = tokenIDs[i].intValue;
        }
        at::Tensor tensor = torch::from_blob(buffer,{1,(int64_t)count},at::kLong);
        torch::autograd::AutoGradMode guard(false);
        at::AutoNonVariableTypeMode non_var_type_mode(true);
        auto outputTuple = _impl.forward({tensor}).toTuple();

        auto outputTensor = outputTuple->elements()[0].toTensor();
        auto sizes = outputTensor.sizes();
        // len will be tokens * vocab size -- sizes[1] * sizes[2] (sizes[0] is batch_size = 1)
        auto positions = sizes[1];
        auto tokens = sizes[2];
        float* floatBuffer = outputTensor.data_ptr<float>();
        if (!floatBuffer) {
            return nil;
        }
        // MARK: This is probably a slow way to create this 2D NSArray
        NSMutableArray* results = [[NSMutableArray alloc] initWithCapacity: positions];
        for (int i = 0; i < positions; i++) {
            NSMutableArray* weights = [[NSMutableArray alloc] initWithCapacity: tokens];
            for (int j = 0; j < tokens; j++) {
                [weights addobject:@(floatBuffer[i*positions + j])];
            }
            [results addobject:weights];
        }
        return [results copy];
    } catch (const std::exception& exception) {
        NSLog(@"%s",exception.what());
    }
    return nil;
}

请注意,我在iOS中的初始化代码确实在TorchScript模型上调用eval()

更新:一项观察;当我在上面加载经过训练的模型时尝试使用config的方式导致未设置torchscript标志-我假设它完全忽略了我的config并从预训练中获取文件。因此,如文档所述,我已将其移至from_pretrained('./trained_RoBERTa',torchscript=True)。请注意,iOS上的输出也存在同样的问题。

更新2:我想我会尝试在Python中测试跟踪模型。不确定是否可以正常工作,但是输出是否与原始模型中的相同测试相匹配:

traced_test = traced_model(input)
pred = torch.argmax(traced_test[0],dim=2).squeeze(0)
pred_str = tokenizer.decode(pred[1:-1].tolist())
print(pred_str)

这让我觉得iOS Obj-C ++执行中有一些问题。加载跟踪的模型/导出的代码确实在模型上调用.eval(),btw(我意识到这可能是对不同输出的可能解释):

- (nullable instancetype)initWithFileAtPath:(Nsstring*)filePath {
    self = [super init];
    if (self) {
        try {
            auto qengines = at::globalContext().supportedQEngines();
            if (std::find(qengines.begin(),qengines.end(),at::QEngine::QNNPACK) != qengines.end()) {
                at::globalContext().setQEngine(at::QEngine::QNNPACK);
            }
            _impl = torch::jit::load(filePath.UTF8String);
            _impl.eval();
        } catch (const std::exception& exception) {
            NSLog(@"%s",exception.what());
            return nil;
        }
    }
    return self;
}

更新3:Uhhhmmm ...这绝对是一个面对面的时刻(在一个浪费的周末之后)...我决定从Obj-C返回一个平坦的NSArray并在Swift中对2D数组进行整形,除了转移一个令牌(我认为这只是[CLS]),现在输出正确。我想我的Obj-C确实生锈。遗憾的是,我仍然看不到该问题,但是它现在正在运行,所以我要投降。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?