
How to resolve ValueError: Target size (torch.Size([8])) must be the same as input size (torch.Size([8, 2]))

I am trying to implement sentiment analysis (positive or negative labels) with BERT, and I want to add a BiLSTM layer to see whether I can improve the accuracy of the HuggingFace pretrained model. I have the following code and a couple of questions:

import numpy as np
import pandas as pd
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset,DataLoader,RandomSampler,SequentialSampler
from transformers import BertTokenizer,BertModel,BertConfig
from torch import cuda
import re
import torch.nn as nn

device = 'cuda' if cuda.is_available() else 'cpu'
MAX_LEN = 200
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-05 #5e-5,3e-5 or 2e-5
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CustomDataset(Dataset):
 def __init__(self,dataframe,tokenizer,max_len):
  self.tokenizer = tokenizer
  self.data = dataframe
  self.comment_text = dataframe.review
  self.targets = self.data.sentiment
  self.max_len = max_len
 def __len__(self):
  return len(self.comment_text)
 def __getitem__(self,index):
  comment_text = str(self.comment_text[index])
  comment_text = " ".join(comment_text.split())

  inputs = self.tokenizer.encode_plus(
   comment_text,
   None,
   add_special_tokens=True,
   max_length=self.max_len,
   padding='max_length',
   truncation=True,
   return_token_type_ids=True
  )
  ids = inputs['input_ids']
  mask = inputs['attention_mask']
  token_type_ids = inputs["token_type_ids"]

  return {
   'ids': torch.tensor(ids, dtype=torch.long),
   'mask': torch.tensor(mask, dtype=torch.long),
   'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
   'targets': torch.tensor(self.targets[index], dtype=torch.float)
  }
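
# Note: df is assumed to be a pandas DataFrame with 'review' and 'sentiment'
# columns; how it is loaded is not shown in the question.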
train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)
train_params = {'batch_size': TRAIN_BATCH_SIZE,'shuffle': True,'num_workers': 0}
test_params = {'batch_size': VALID_BATCH_SIZE,'num_workers': 0}
training_loader = DataLoader(training_set,**train_params)
testing_loader = DataLoader(testing_set,**test_params)


class BERTClass(torch.nn.Module):
 def __init__(self):
   super(BERTClass,self).__init__()
   self.bert = BertModel.from_pretrained('bert-base-uncased',return_dict=False,num_labels =2)
   self.lstm = nn.LSTM(768,256,batch_first=True,bidirectional=True)
   self.linear = nn.Linear(256*2,2)

 def forward(self,ids,mask,token_type_ids):
  sequence_output,pooled_output = self.bert(ids,attention_mask=mask,token_type_ids = token_type_ids)
  lstm_output,(h,c) = self.lstm(sequence_output)  ## extract the 1st token's embeddings
  hidden = torch.cat((lstm_output[:, -1, :256], lstm_output[:, 0, 256:]), dim=-1)
  linear_output = self.linear(lstm_output[:,-1].view(-1,256 * 2))

  return linear_output

model = BERTClass()
model.to(device)
print(model)
def loss_fn(outputs,targets):
 return torch.nn.BCEWithLogitsLoss()(outputs,targets)
optimizer = torch.optim.Adam(params =  model.parameters(),lr=LEARNING_RATE)

def train(epoch):
 model.train()
 for _,data in enumerate(training_loader,0):
  ids = data['ids'].to(device,dtype=torch.long)
  mask = data['mask'].to(device,dtype=torch.long)
  token_type_ids = data['token_type_ids'].to(device,dtype=torch.long)
  targets = data['targets'].to(device,dtype=torch.float)
  outputs = model(ids, mask, token_type_ids)
  optimizer.zero_grad()
  loss = loss_fn(outputs,targets)
  if _ % 5000 == 0:
   print(f'Epoch: {epoch},Loss:  {loss.item()}')
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

for epoch in range(EPOCHS):
  train(epoch)

With the code above I get the error Target size (torch.Size([8])) must be the same as input size (torch.Size([8, 2])). I checked online and tried targets = targets.unsqueeze(2), but then I got another error saying the unsqueeze dimension must be in the range [-2, 1]. I also tried changing the loss function to

def loss_fn(outputs,targets):
 return torch.nn.BCELoss()(outputs,targets)

But I still get the same error. Can anyone suggest a way to fix this, or what I should do to make this work? Thank you very much.
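
For reference, the mismatch comes from BCEWithLogitsLoss requiring the logits and the targets to have the same shape: the linear head above emits two logits per example (shape [8, 2] for a batch of 8), while the targets are single 0/1 labels (shape [8]). Below is a minimal, self-contained sketch using dummy tensors (not the actual model) that reproduces the shapes involved and shows two common ways the shapes could be reconciled; which one fits best depends on the rest of the pipeline.

import torch
import torch.nn as nn

# Dummy batch: 8 examples, 2 logits per example (as produced by the linear
# head above), targets stored as 0./1. floats (as in CustomDataset).
logits = torch.randn(8, 2)
targets = torch.tensor([0., 1., 1., 0., 1., 0., 0., 1.])

# Option 1: keep the two-logit head and use CrossEntropyLoss, which expects
# class indices of shape [batch_size] as long tensors, not same-shaped floats.
loss_ce = nn.CrossEntropyLoss()(logits, targets.long())

# Option 2: keep BCEWithLogitsLoss but shrink the head to a single logit
# (e.g. nn.Linear(256 * 2, 1)), so logits and targets are both [batch_size].
single_logit = torch.randn(8, 1)  # stand-in for what a one-unit head returns
loss_bce = nn.BCEWithLogitsLoss()(single_logit.squeeze(-1), targets)

print(loss_ce.item(), loss_bce.item())

With option 1 the targets in CustomDataset would be created with dtype=torch.long instead of dtype=torch.float; with option 2 self.linear would become nn.Linear(256 * 2, 1) and thresholding a sigmoid of the single logit would give the predicted class.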
