微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

PyTorch:“类型错误:在 DataLoader 工作进程 0 中捕获到类型错误”

如何解决PyTorch:“类型错误:在 DataLoader 工作进程 0 中捕获到类型错误”

我正在尝试实施 RoBERTa 模型进行情感分析。首先,我声明了 GPReviewDataset 来创建 PyTorch 数据集。

MAX_LEN = 160
class GPReviewDataset(Dataset):
  def __init__(self,reviews,targets,tokenizer,max_len):
    self.reviews = reviews
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len
  def __len__(self):
    return len(self.reviews)
  def __getitem__(self,item):
    review = str(self.reviews[item])
    target = self.targets[item]
    encoding = self.tokenizer.encode_plus(
      review,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',)
    return {
      'review_text': review,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'targets': torch.tensor(target,dtype=torch.long)
    }

接下来,我实现 create_data_loader 来创建几个数据加载器。这是一个帮助函数来做到这一点:

def create_data_loader(df,max_len,batch_size):
  ds = GPReviewDataset(
    reviews=df.text.to_numpy(),targets=df.sentiment.to_numpy(),tokenizer=tokenizer,max_len=max_len
  )
  return DataLoader(
    ds,batch_size=batch_size,num_workers=4
  )
BATCH_SIZE = 16
train_data_loader = create_data_loader(df_train,MAX_LEN,BATCH_SIZE)
val_data_loader = create_data_loader(df_val,BATCH_SIZE)
test_data_loader = create_data_loader(df_test,BATCH_SIZE)
dt = next(iter(train_data_loader))

但是,当我运行此代码时,它会停止并给出这些错误

TypeError                                 Traceback (most recent call last)
<ipython-input-35-a673c0794f60> in <module>()
----> 1 dt = next(iter(train_data_loader))

3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/DataLoader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/DataLoader.py in _next_data(self)
   1083             else:
   1084                 del self._task_info[idx]
-> 1085                 return self._process_data(data)
   1086 
   1087     def _try_put_index(self):

/usr/local/lib/python3.6/dist-packages/torch/utils/data/DataLoader.py in _process_data(self,data)
   1109         self._try_put_index()
   1110         if isinstance(data,ExceptionWrapper):
-> 1111             data.reraise()
   1112         return data
   1113 

/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429 
    430 

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py",line 198,in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py",line 44,in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py",in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-18-1e537ce5a428>",line 25,in __getitem__
    'targets': torch.tensor(target,dtype=torch.long)
TypeError: new(): invalid data type 'str'

我不明白为什么会这样,谁能帮我解释一下。

解决方法

您需要将类定义为整数。我假设您正在处理分类问题。但看起来您已将类定义为字符串。您需要将类从字符串转换为整数。比如df.sentiment对应正数,必须用0表示,df.sentiment对应负数,需要在新列中用1表示。

def to_int_sentiment(label):
  if label == "positive":
    return 0
  elif label == "negative":
    return 1

df['int_sentiment'] = df.sentiment.apply(to_int_sentiment)

那么您应该使用列 df.int_sentiment 而不是 df.sentiment。所以你必须改变 create_data_loader 函数如下。

def create_data_loader(df,tokenizer,max_len,batch_size):
  ds = GPReviewDataset(
    reviews=df.text.to_numpy(),targets=df.int_sentiment.to_numpy(),tokenizer=tokenizer,max_len=max_len
  )
  return DataLoader(
    ds,batch_size=batch_size,num_workers=4
  )

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。