微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 XLNet 进行情感分析 - 设置正确的重塑参数

如何解决使用 XLNet 进行情感分析 - 设置正确的重塑参数

this link 之后,我尝试使用自己的数据进行情绪分析。但我收到此错误

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed exec> in <module>

<ipython-input-41-5f2f35b7976e> in train_epoch(model,data_loader,optimizer,device,scheduler,n_examples)
      7 
      8     for d in data_loader:
----> 9         input_ids = d["input_ids"].reshape(4,64).to(device)
     10         attention_mask = d["attention_mask"].to(device)
     11         targets = d["targets"].to(device)

RuntimeError: shape '[4,64]' is invalid for input of size 64

当我尝试运行此代码

history = defaultdict(list)
best_accuracy = 0

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc,train_loss = train_epoch(
        model,train_data_loader,len(df_train)
    )

    print(f'Train loss {train_loss} Train accuracy {train_acc}')

    val_acc,val_loss = eval_model(
        model,val_data_loader,len(df_val)
    )

    print(f'Val loss {val_loss} Val accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

我知道这个错误与我的数据的形状有关,但我不确定如何找到正确的 reshape 参数来完成这项工作。

解决方法

你的例子中的形状 [4,64] 实际上是 [batch size,max_sequence_length]

所以也许你可以用你的价值观替换它们......

,

然而,您还没有发布您的示例数据,但很明显您是如何使用您的 reshape function。关于您将 reshape d["input_ids"] 变成 (4,64) 的问题,那么 d["input_ids"] 应该是一个大小为 256 的数组,但实际上在您提供给模型大小为 64

因此,您需要根据数据的形式(其倍数为 64)使用 d["input_ids"] 之类的东西重塑 (1,64) or (2,32) or (4,16)

只是说明相同:

>>> a = np.arange(256).reshape(4,64)
>>> a
array([[  0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63],[ 64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127],[128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191],[192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255]])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。