微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

生成器“ TypeError:'generator'对象不是迭代器”

如何解决生成器“ TypeError:'generator'对象不是迭代器”

我遇到了同样的问题,我设法通过定义一个__next__方法解决了这个问题:

class My_Generator(Sequence):
    def __init__(self, image_filenames, labels, batch_size):
        self.image_filenames, self.labels = image_filenames, labels
        self.batch_size = batch_size
        self.n = 0
        self.max = self.__len__()


    def __len__(self):
        return np.ceil(len(self.image_filenames) / float(self.batch_size))

    def __getitem__(self, idx):
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
        resize(imread(file_name), (200, 200))
           for file_name in batch_x]), np.array(batch_y)

    def __next__(self):
        if self.n >= self.max:
           self.n = 0
        result = self.__getitem__(self.n)
        self.n += 1
        return result

请注意,我在__init__函数中声明了两个新变量。

解决方法

由于RAM内存的限制,我遵循了这些指令,并构建了一个生成器,该生成器可以绘制小批量并将其传递给Keras的fit_generator。但是,即使我继承了Sequence,Keras也无法使用多重处理来准备队列。

这是我的多处理生成器。

class My_Generator(Sequence):
    def __init__(self,image_filenames,labels,batch_size):
        self.image_filenames,self.labels = image_filenames,labels
        self.batch_size = batch_size

    def __len__(self):
        return np.ceil(len(self.image_filenames) / float(self.batch_size))

    def __getitem__(self,idx):
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

    return np.array([
        resize(imread(file_name),(200,200))
           for file_name in batch_x]),np.array(batch_y)

主要功能:

batch_size = 100
num_epochs = 10
train_fnames = []
mask_training = []
val_fnames = [] 
mask_validation = []

我希望生成器按ID分别在不同线程中读取文件夹中的批处理(其中ID看起来像:{number} .csv用于原始图像,{number}
_label.csv用于掩码图像)。最初,我建立了另一个更优雅的类,将每个数据存储在一个.h5文件而不是目录中。但是阻止了同样的问题。因此,如果您有执行此操作的代码,那么我也是。

for dirpath,_,fnames in os.walk('./train/'):
    for fname in fnames:
        if 'label' not in fname:
            training_filenames.append(os.path.abspath(os.path.join(dirpath,fname)))
        else:
            mask_training.append(os.path.abspath(os.path.join(dirpath,fname)))
for dirpath,fnames in os.walk('./validation/'):
    for fname in fnames:
        if 'label' not in fname:
            validation_filenames.append(os.path.abspath(os.path.join(dirpath,fname)))
        else:
            mask_validation.append(os.path.abspath(os.path.join(dirpath,fname)))


my_training_batch_generator = My_Generator(training_filenames,mask_training,batch_size)
my_validation_batch_generator = My_Generator(validation_filenames,mask_validation,batch_size)
num_training_samples = len(training_filenames)
num_validation_samples = len(validation_filenames)

在此,该模型不在范围内。我相信这不是模型的问题,所以我不会粘贴它。

mdl = model.compile(...)
mdl.fit_generator(generator=my_training_batch_generator,steps_per_epoch=(num_training_samples // batch_size),epochs=num_epochs,verbose=1,validation_data=None,#my_validation_batch_generator,# validation_steps=(num_validation_samples // batch_size),use_multiprocessing=True,workers=4,max_queue_size=2)

该错误表明我创建的类不是Iterator:

Traceback (most recent call last):
File "test.py",line 141,in <module> max_queue_size=2)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py",line 2177,in fit_generator
initial_epoch=initial_epoch)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py",line 147,in fit_generator
generator_output = next(output_generator)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/utils/data_utils.py",line 831,in get six.reraise(value.__class__,value,value.__traceback__)
File "/anaconda3/lib/python3.6/site-packages/six.py",line 693,in reraise
raise value
TypeError: 'My_Generator' object is not an iterator

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。