How to fix ValueError: train_dataset does not implement __len__, max_steps has to be specified
I have a TFRecord dataset that I want to use to train an XLNet model. This is what I wrote:
import os
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import XLNetConfig, XLNetModel
from transformers import Trainer, TrainingArguments

train_dataset = tf.data.TFRecordDataset('data.tfrecord')

# Initializing an XLNet configuration
configuration = XLNetConfig(use_mems_train=True)
model = XLNetModel(configuration)
print(configuration)

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total # of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    per_device_eval_batch_size=64,
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                     # the instantiated Transformers model to be trained
    args=training_args,              # training arguments, defined above
    train_dataset=train_dataset,     # training dataset
)
trainer.train()
Traceback (most recent call last):
  File "dfgd.py", line 29, in <module>
    train_dataset=train_dataset,  # training dataset
  File "C:\Users\DSP\AppData\Roaming\Python\python37\site-packages\transformers\trainer.py", line 313, in __init__
    raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
ValueError: train_dataset does not implement __len__, max_steps has to be specified
The error is raised by this check in transformers' trainer.py:
# Enforce rules on using datasets with no __len__
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
    raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
    raise ValueError("eval_dataset must implement __len__")
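The condition quoted above can be reproduced without transformers installed. The following is a minimal sketch, not the library's code: `StreamingDataset`, `SizedDataset`, and `passes_length_check` are hypothetical names made up for illustration, and the default `max_steps=-1` is assumed to match the `TrainingArguments` default. It shows that an object which only iterates fails the `collections.abc.Sized` test, while defining `__len__` or passing a positive `max_steps` satisfies it.

```python
import collections.abc


class StreamingDataset:
    """Hypothetical dataset that can only be iterated (no __len__)."""

    def __init__(self, records):
        self._records = records

    def __iter__(self):
        return iter(self._records)


class SizedDataset(StreamingDataset):
    """Same data, but also reports its length."""

    def __len__(self):
        return len(self._records)


def passes_length_check(train_dataset, max_steps=-1):
    """Mirror the train_dataset condition; True means no ValueError."""
    return (
        train_dataset is None
        or isinstance(train_dataset, collections.abc.Sized)
        or max_steps > 0
    )


records = ["a", "b", "c"]
print(passes_length_check(StreamingDataset(records)))                 # False
print(passes_length_check(SizedDataset(records)))                     # True
print(passes_length_check(StreamingDataset(records), max_steps=100))  # True
```

`collections.abc.Sized` recognizes any class that defines `__len__`, so no explicit inheritance is needed to pass the check.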
How can I fix this?
Solution
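I see both Trainer() and TFTrainer() in the Hugging Face documentation.
The documentation says Trainer accepts a torch.utils.data.dataset.Dataset, but it does not mention tf.data.Dataset.
Since you are using a tf.data.Dataset, you may want to try TFTrainer instead of Trainer.
Or, to stay with Trainer, try implementing a PyTorch dataset that defines __len__, something like: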
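import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings  # tokenizer output: dict of per-example lists
        self.labels = labels        # one label per example

    def __getitem__(self, idx):
        # Build one example as a dict of tensors, plus its label
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        # Trainer needs this to work out the number of steps per epoch
        return len(self.labels)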
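If the data truly cannot report a length, the error message names the other way out: give `TrainingArguments` a positive `max_steps`. Below is a rough sketch of deriving that number. `compute_max_steps` is a hypothetical helper and the 25,000 example count is an assumed figure, so substitute the real size of your data. Note that Trainer is PyTorch-based, so this applies to a PyTorch iterable dataset; it does not by itself make a tf.data.Dataset usable with Trainer.

```python
import math


def compute_max_steps(num_examples, per_device_batch_size, num_epochs,
                      num_devices=1, grad_accum_steps=1):
    """Optimizer steps needed to pass over the data num_epochs times.

    Hypothetical helper for illustration; not part of transformers.
    """
    examples_per_step = per_device_batch_size * num_devices * grad_accum_steps
    steps_per_epoch = math.ceil(num_examples / examples_per_step)
    return steps_per_epoch * num_epochs


# Assumed figures: 25,000 examples, batch size 16, 3 epochs (as in the question).
max_steps = compute_max_steps(num_examples=25_000,
                              per_device_batch_size=16,
                              num_epochs=3)
print(max_steps)  # 4689
```

The result would then be passed as `TrainingArguments(..., max_steps=max_steps)`; when `max_steps` is positive it overrides `num_train_epochs`.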