
Tuple index out of range error when trying to train a neural network with generators


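I am working on text generation with a sequence-to-sequence model and have run into a problem I cannot solve. Description: I train a neural network on questions and answers. When I prepare the decoder data, I get the following tensor dimensions: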

tokenizedAnswers = tokenizer.texts_to_sequences(answers)  # convert the text answers to sequences of word indexes

for i in range(len(tokenizedAnswers)):
  tokenizedAnswers[i] = tokenizedAnswers[i][1:]  # drop the START tag from each answer sequence

paddedAnswers = pad_sequences(tokenizedAnswers, maxlen=maxLenAnswers, padding='post')  # make all sequences the same length, filling shorter answers with zeros

decoderForOutput = utils.to_categorical(paddedAnswers, vocabularySize)  # convert to one-hot vectors

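At this stage, the paddedAnswers variable holds a two-dimensional NumPy array of shape (4594, 25), which corresponds to my dataset, and the decoderForOutput variable holds a three-dimensional NumPy array of shape (4594, 25, 10785). This also matches my dataset.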

paddedAnswers.shape   ### (4594,25)
decoderForOutput.shape   ### (4594, 25, 10785)
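
For a sense of scale, the size of the dense one-hot tensor can be estimated from the shape quoted above. This is a rough sketch that assumes 4 bytes per element (float32, which is the default dtype of `to_categorical`); under that assumption the tensor takes roughly 5 GB.

```python
# Shape quoted in the question: (answers, max answer length, vocabulary size).
shape = (4594, 25, 10785)

# Assumption: float32 elements, i.e. 4 bytes each.
bytes_per_element = 4

# Multiply the dimensions to get the total number of elements.
n_elements = 1
for dim in shape:
    n_elements *= dim

n_bytes = n_elements * bytes_per_element
print(f"{n_elements:,} elements, about {n_bytes / 1e9:.2f} GB")
```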

Here, to save memory, I want to use Python generators. To do that, I created two generator functions:

def generator_from_two_dimensional_tensor(arg):
  for i in range(arg.shape[0]):
    for j in range(arg.shape[1]):
      yield arg[i,j]

def generator_from_three_dimensional_tensor(arg):
  for i in range(arg.shape[0]):
    for j in range(arg.shape[1]):
      for k in range(arg.shape[2]):
        yield arg[i,j,k]
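
It may help to look at what these generators actually produce. The sketch below reuses the two-dimensional function on a small made-up array (`toy_padded` is a hypothetical stand-in for paddedAnswers) and shows that it yields individual scalars in row-major order, so collecting its output gives a flat one-dimensional list and the row structure is lost.

```python
import numpy as np

def generator_from_two_dimensional_tensor(arg):
    # Same logic as in the question: yields one scalar at a time.
    for i in range(arg.shape[0]):
        for j in range(arg.shape[1]):
            yield arg[i, j]

# Hypothetical toy data: 2 answers, 3 token indexes each.
toy_padded = np.array([[5, 2, 0],
                       [7, 1, 4]])

# Collecting the generator output gives a flat list of 6 scalars.
flat = [int(x) for x in generator_from_two_dimensional_tensor(toy_padded)]
print(flat)                  # [5, 2, 0, 7, 1, 4]
print(np.array(flat).shape)  # (6,)
```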

I substituted these functions into my code as shown below. I also made a generator from the tokenized answers:

tokenizedAnswer = tokenizer.texts_to_sequences(answers)
for i in range(len(tokenizedAnswer)):
  tokenizedAnswer[i] = tokenizedAnswer[i][1:]
  
generator_tokenized_answers = (x for x in tokenizedAnswer)
gen_paddedAnswers = generator_from_two_dimensional_tensor(pad_sequences([x for x in generator_tokenized_answers],padding='post'))

decoderForOutput = generator_from_three_dimensional_tensor(utils.to_categorical([x for x in gen_paddedAnswers],vocabularySize))

The code runs without raising any errors. But when I try to train the network, an error occurs:

history = model.fit([[x for x in gen_encoderForInput],[y for y in gen_decoderForInput]],[z for z in decoderForOutput],batch_size=50,epochs=20)

<ipython-input-6-a3a67c64deaf> in generator_three_tensor(arg)
      2   for i in range(arg.shape[0]):
      3     for j in range(arg.shape[1]):
----> 4       for k in range(arg.shape[2]):
      5         yield arg[i,k]
      6 

IndexError: tuple index out of range

Why this happens is still a mystery to me...
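
One likely explanation, sketched here with NumPy only: because the two-dimensional generator yields scalars, `[x for x in gen_paddedAnswers]` is a flat list, so one-hot encoding it gives a two-dimensional array and `arg.shape[2]` does not exist. Generators are lazy, which would explain why the error only surfaces once `model.fit` starts consuming them. Note also that wrapping a generator in a list comprehension materializes all of its values, so no memory is saved. In the sketch, `one_hot` is a hypothetical stand-in for `utils.to_categorical`, the toy data is invented, and `batch_generator` is only one possible approach under those assumptions, not a tested fix for the original model.

```python
import numpy as np

def one_hot(indices, num_classes):
    # NumPy stand-in for utils.to_categorical: appends one axis of size num_classes.
    return np.eye(num_classes, dtype="float32")[np.asarray(indices)]

# Hypothetical toy data: 4 answers, 3 token indexes each, vocabulary of 6.
padded = np.array([[1, 2, 0], [3, 0, 0], [4, 5, 1], [2, 2, 0]])
vocab_size = 6

# What the question's code effectively does: flatten first, then one-hot encode.
flat = [x for row in padded for x in row]
collapsed = one_hot(flat, vocab_size)
print(collapsed.shape)  # (12, 6): only two axes

# Asking for a third axis reproduces the error from the traceback.
try:
    collapsed.shape[2]
except IndexError as err:
    message = str(err)
print(message)  # tuple index out of range

def batch_generator(enc_in, dec_in, padded_out, num_classes, batch_size):
    # Loops forever, one-hot encoding only the current batch of answers.
    while True:
        for start in range(0, len(padded_out), batch_size):
            stop = start + batch_size
            targets = one_hot(padded_out[start:stop], num_classes)
            yield (enc_in[start:stop], dec_in[start:stop]), targets

# The first two arguments are placeholders; the real encoder and decoder
# input arrays from the question would go there.
gen = batch_generator(padded, padded, padded, vocab_size, batch_size=2)
(inputs_a, inputs_b), targets = next(gen)
print(targets.shape)  # (2, 3, 6)
```

A generator of this shape keeps only one batch of one-hot targets in memory at a time. It could presumably be passed to `model.fit` in place of the lists, together with `steps_per_epoch`, since an endless generator has no length; that part is not verified here.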
