
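ValueError: train_dataset does not implement __len__, max_steps has to be specified

How to fix ValueError: train_dataset does not implement __len__, max_steps has to be specified

I have a TFRecord dataset that I want to use to train an XLNet model. This is what I wrote: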

import tensorflow as tf
from transformers import Trainer, TrainingArguments, XLNetConfig, XLNetModel

# Load the TFRecord file as a tf.data dataset
train_dataset = tf.data.TFRecordDataset('data.tfrecord')

# Initializing an XLNet configuration
configuration = XLNetConfig(use_mems_train=True)
model = XLNetModel(configuration)
print(configuration)

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total # of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size per device during evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                  # the instantiated Transformers model to be trained
    args=training_args,           # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
)
trainer.train()

However, after running the code above, I get the following error:

Traceback (most recent call last):
  File "dfgd.py", line 29, in <module>
    train_dataset=train_dataset,  # training dataset
  File "C:\Users\DSP\AppData\Roaming\Python\python37\site-packages\transformers\trainer.py", line 313, in __init__
    raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
ValueError: train_dataset does not implement __len__, max_steps has to be specified

This is the part of the source code that raises the error:

# Enforce rules on using datasets with no __len__
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
    raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
    raise ValueError("eval_dataset must implement __len__")

How can I fix this?

Solution

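The Hugging Face documentation describes both Trainer() and TFTrainer().

According to the documentation, Trainer accepts a torch.utils.data.dataset.Dataset; it makes no mention of tf.data.Dataset.

Since you are using a tf.data.Dataset, you may want to try TFTrainer instead of Trainer.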
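The check quoted in the question tests for collections.abc.Sized, which any class that defines __len__ satisfies. The following is a minimal sketch of that rule in isolation, not the library's actual code path: StreamingDataset, SizedDataset, and check_train_dataset are hypothetical names introduced for illustration, and the default of -1 for max_steps is used to stand for "not specified".

```python
import collections.abc


class StreamingDataset:
    """Hypothetical stand-in for a dataset that can only be iterated."""

    def __init__(self, records):
        self._records = records

    def __iter__(self):
        return iter(self._records)


class SizedDataset(StreamingDataset):
    """Same data, but with __len__ defined."""

    def __len__(self):
        return len(self._records)


def check_train_dataset(train_dataset, max_steps):
    """Mirror the train_dataset rule quoted in the question."""
    if (
        train_dataset is not None
        and not isinstance(train_dataset, collections.abc.Sized)
        and max_steps <= 0
    ):
        raise ValueError(
            "train_dataset does not implement __len__, max_steps has to be specified"
        )


records = ["a", "b", "c"]

# Passes: the dataset defines __len__.
check_train_dataset(SizedDataset(records), max_steps=-1)

# Passes: no __len__, but max_steps is positive.
check_train_dataset(StreamingDataset(records), max_steps=100)

# Fails: no __len__ and no max_steps.
try:
    check_train_dataset(StreamingDataset(records), max_steps=-1)
except ValueError as err:
    print(err)
```

Passing the check only means the constructor stops raising; the dataset still has to be something the training loop can actually consume.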

Alternatively, try implementing something like the following, a PyTorch dataset that defines __len__:

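import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        # encodings: dict mapping each feature name to a per-example sequence
        # labels: one label per example
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Build a single example as a dict of tensors
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        # Defining __len__ is what satisfies the Trainer check
        return len(self.labels)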
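A dataset with __len__ lets the training loop work out how many steps to run; without it, that number has to be supplied as max_steps. Below is a rough sketch of the arithmetic, assuming a single device, no gradient accumulation, and that the final partial batch of each epoch is kept. The function name total_training_steps and the example count of 1,000 are hypothetical; the batch size and epoch count are taken from the question.

```python
import math


def total_training_steps(num_examples, batch_size, num_epochs):
    """Estimate optimizer steps for a dataset of known length.

    Assumes one device, no gradient accumulation, and that the final
    partial batch of each epoch is kept.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


# batch_size and num_epochs come from the question; num_examples is hypothetical.
steps = total_training_steps(num_examples=1000, batch_size=16, num_epochs=3)
print(steps)  # 189
```

If the record count of a streaming dataset is known from elsewhere, a figure computed this way could be passed as max_steps in place of num_train_epochs.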
