
The simplest LSTM encoder-decoder architecture with attention using Pytorch

Please help me understand how to write an LSTM (RNN) with attention using an encoder-decoder architecture. I've watched a lot of videos on YouTube, read some articles on towardsdatascience.com and so on, but those examples are very complicated for me. I need the simplest possible RNN example that does what I said above.

Anyway, I tried to write it myself (and it works!), but the loss stays very high and barely changes. I can explain how I think it should work (and I will describe every step I took in the code), but please correct me (I mean, tell me where my mistakes are) and, based on my code, make (if possible) a working LSTM that follows the encoder-decoder-with-attention architecture. I would be very glad to gain this important knowledge. Let's go!

First of all, here is the full code:

import torch.nn as nn
import torch.optim as optim
import torch
import numpy as np

class Encoder(nn.Module):
  def __init__(self):
    super(Encoder,self).__init__()
    self.lstm = nn.LSTM(1,5,batch_first=True) # hidden size 5: a 2-step input gives 2*5 = 10 output values, which the decoder reshapes to (1,10)

  def forward(self,input):
    output,hidden_state = self.lstm(input)
    #print('DATA:',self.lstm.weight_hh_l0[0])
    return output

class Decoder(nn.Module):
  def __init__(self):
    super(Decoder,self).__init__()
    self.lstm = nn.LSTM(10,1,batch_first=True) # input size 10, hidden size 1, so the output can be reshaped to (1,10) below
    self.linear = nn.ModuleList([nn.Linear(10,10) for i in range(10)])
    self.softmax = nn.Softmax(dim=0) # note: this layer is never actually used below

  def forward(self,input):
    input = input.view(1,10)
    initial_input = input.clone()
    for i in range(10):
      if i == 0:
        input = self.linear[i](input)
        #print(input)
      else:
        input = torch.cat((input,self.linear[i](initial_input)),0)
        #print(input)
    input = input.view(-1,10,10)
    output,hidden_state = self.lstm(input)
    output = output.view(1,10)
    return output

class RNN(nn.Module):
  def __init__(self):
    super(RNN,self).__init__()

    self.encoder = Encoder()
    self.decoder = Decoder()
    print(self.encoder)
    print(self.decoder)

    self.loss = nn.CrossEntropyLoss()
    self.encoder_optimizer = optim.Adagrad(self.encoder.parameters(),lr=0.006)
    self.decoder_optimizer = optim.Adagrad(self.decoder.parameters(),lr=0.006)


  def train(self,input,target,test):

    # Encoder
    output_encoder = self.encoder(input)

    # Decoder
    output_decoder = self.decoder(output_encoder)

    if test:
      print(np.argmax(torch.nn.functional.softmax(output_decoder).detach().numpy().flatten()))

    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()

    loss = self.loss(output_decoder,target)
    print(loss)
    
    loss.backward()

    self.encoder_optimizer.step()
    self.decoder_optimizer.step()
    

    return 0

def main():
  rnn = RNN()

  input = [[10,1],[10,2],[10,3],[10,4],[10,5],[10,6],[10,7],[10,8],[10,9],[10,10]]
  labels = [0,1,2,3,4,5,6,7,8,9]

  
  labels = torch.tensor(labels).type(dtype=torch.long)

  losses = []
  for epoch in range(100):
    input_copy = input[epoch % 10]
    labels_copy = [labels[epoch % 10]]
    labels_copy = torch.tensor(labels_copy).type(dtype=torch.long)
    input_copy = torch.tensor(input_copy).float().requires_grad_(True).view(1,2,1) # (batch=1, seq_len=2, features=1)
    rnn.train(input_copy,labels_copy,False)
  
  rnn.train(input_copy,labels_copy,True) # final call with test=True to print the predicted class

main()

Secondly, here is a picture I made to help you understand what I did:

[Click here to see the picture][1]

Also, I should say that this LSTM doesn't have any practical purpose; I just want to understand how to write this type of RNN.

Let's look at the code more closely (read the comments, please). Here is the encoder:

class Encoder(nn.Module):
  def __init__(self):
    super(Encoder,self).__init__()
    self.lstm = nn.LSTM(1,5,batch_first=True) #input is like [[1],[2],...,[10]],also,I'm planning to get WHOLE output from this nn.LSTM (not the last state),I mean the variable 'output' down below

  def forward(self,input):
    output,hidden_state = self.lstm(input) # just implement nn.LSTM
    #print('DATA:',self.lstm.weight_hh_l0[0])
    return output
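Just to be clear about what I mean by the WHOLE output: as far as I understand, nn.LSTM returns both the per-step outputs and the final (h_n, c_n) state, so taking 'output' really gives one vector per input step. A tiny sketch of the shapes (assuming hidden size 5 and a 2-step, 1-feature input):

import torch
import torch.nn as nn

# what nn.LSTM returns (assumed sizes: hidden_size=5, a 2-step sequence of single numbers)
lstm = nn.LSTM(input_size=1,hidden_size=5,batch_first=True)
x = torch.tensor([[10.0],[1.0]]).view(1,2,1)   # (batch=1, seq_len=2, features=1)
output,(h_n,c_n) = lstm(x)
print(output.shape)                # torch.Size([1, 2, 5]) -> the WHOLE output, one vector per step
print(h_n.shape)                   # torch.Size([1, 1, 5]) -> only the last hidden state
print(output.reshape(1,-1).shape)  # torch.Size([1, 10]) -> flattened, this is what my decoder consumes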

And here is the decoder (the most complicated part for me):

class Decoder(nn.Module):
  def __init__(self):
    super(Decoder,self).__init__()
    self.lstm = nn.LSTM(10,1,batch_first=True) #initialize RNN for the decoder
    self.linear = nn.ModuleList([nn.Linear(10,10) for i in range(10)]) #generating linear layer in order to kind of apply sub neural network that will somehow predict weights for important 'words' in our 'sentence' and I decided that those 10 linear layers (they were generated with different weights) should be applied to output from Encoder and each this linear layer should produce the 'input word' for the Decoder's LSTM (you can see what I said in the picture)
    self.softmax = nn.Softmax(dim=0)

  def forward(self,input): #doing what I said before
    input = input.view(1,10)
    initial_input = input.clone()
    for i in range(10):
      if i == 0:
        input = self.linear[i](input)
      else:
        input = torch.cat((input,self.linear[i](initial_input)),0) #merge all outputs from linear layers in one tensor,so,we've made the input 'sentence' for our Decoder's LSTM
        #print(input)
    input = input.view(-1,10,10)
    output,hidden_state = self.lstm(input)
    output = output.view(1,10)
    return output
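For comparison only (this is not what my code does, just how I've seen attention usually written): each encoder step is scored against the current decoder state, the scores go through a softmax, and the encoder outputs are summed with those weights. A sketch with made-up names (attention_scores, attention_weights, context) and assumed sizes (hidden 5, 2 encoder steps):

import torch
import torch.nn.functional as F

# dot-product attention sketch (assumed shapes: batch=1, 2 encoder steps, hidden size 5)
encoder_outputs = torch.randn(1,2,5)      # (batch, src_len, hidden): the WHOLE encoder output
decoder_hidden = torch.randn(1,1,5)       # (batch, 1, hidden): current decoder state
attention_scores = torch.bmm(decoder_hidden,encoder_outputs.transpose(1,2))  # (1, 1, 2)
attention_weights = F.softmax(attention_scores,dim=-1)                       # weights over the 2 source steps
context = torch.bmm(attention_weights,encoder_outputs)                       # (1, 1, 5) weighted sum
# 'context' is then fed to the decoder (e.g. concatenated with its input) at every decoding step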

In the class called 'RNN' we just initialize our encoder and decoder and do the standard steps for any simple RNN (anyway, I'm not sure I implemented all of it in the right order, because now I have two RNNs in one program; if there were only one RNN, I'm sure I would have done everything correctly).
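As far as I understand, the standard order for one step is: zero the gradients, run the forward pass, compute the loss, call backward, then step the optimizers; a single optimizer over both modules' parameter lists should also work and is harder to get wrong. A minimal sketch (assuming the Encoder/Decoder classes above and the same Adagrad settings):

# one optimizer over both modules (sketch, assuming Encoder/Decoder as defined above)
encoder,decoder = Encoder(),Decoder()
optimizer = optim.Adagrad(list(encoder.parameters()) + list(decoder.parameters()),lr=0.006)
criterion = nn.CrossEntropyLoss()

def train_step(x,target):
  optimizer.zero_grad()              # clear old gradients first
  logits = decoder(encoder(x))       # forward through both modules
  loss = criterion(logits,target)    # logits: (1, 10), target: (1,) class index
  loss.backward()                    # backprop through decoder AND encoder
  optimizer.step()
  return loss.item()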

OK, here is the result (this is the loss):

tensor(2.4289,grad_fn=<NllLossBackward>)
tensor(2.3605,grad_fn=<NllLossBackward>)
tensor(2.4271,grad_fn=<NllLossBackward>)
tensor(2.4054,grad_fn=<NllLossBackward>)
tensor(2.4332,grad_fn=<NllLossBackward>)
tensor(2.2394,grad_fn=<NllLossBackward>)
tensor(2.2026,grad_fn=<NllLossBackward>)
tensor(2.0889,grad_fn=<NllLossBackward>)
tensor(2.3166,grad_fn=<NllLossBackward>)
tensor(2.3591,grad_fn=<NllLossBackward>)
tensor(2.3845,grad_fn=<NllLossBackward>)
tensor(2.3353,grad_fn=<NllLossBackward>)
tensor(2.4073,grad_fn=<NllLossBackward>)
tensor(2.3933,grad_fn=<NllLossBackward>)
tensor(2.4244,grad_fn=<NllLossBackward>)
tensor(2.2370,grad_fn=<NllLossBackward>)
tensor(2.2015,grad_fn=<NllLossBackward>)
tensor(2.3139,grad_fn=<NllLossBackward>)
tensor(2.3567,grad_fn=<NllLossBackward>)
tensor(2.3785,grad_fn=<NllLossBackward>)
tensor(2.3313,grad_fn=<NllLossBackward>)
tensor(2.4027,grad_fn=<NllLossBackward>)
tensor(2.3904,grad_fn=<NllLossBackward>)
tensor(2.4217,grad_fn=<NllLossBackward>)
tensor(2.2371,grad_fn=<NllLossBackward>)
tensor(2.2024,grad_fn=<NllLossBackward>)
tensor(2.0907,grad_fn=<NllLossBackward>)
tensor(2.3134,grad_fn=<NllLossBackward>)
tensor(2.3556,grad_fn=<NllLossBackward>)
tensor(2.3751,grad_fn=<NllLossBackward>)
tensor(2.3290,grad_fn=<NllLossBackward>)
tensor(2.3998,grad_fn=<NllLossBackward>)
tensor(2.3885,grad_fn=<NllLossBackward>)
tensor(2.4197,grad_fn=<NllLossBackward>)
tensor(2.2374,grad_fn=<NllLossBackward>)
tensor(2.2034,grad_fn=<NllLossBackward>)
tensor(2.0925,grad_fn=<NllLossBackward>)
tensor(2.3135,grad_fn=<NllLossBackward>)
tensor(2.3549,grad_fn=<NllLossBackward>)
tensor(2.3727,grad_fn=<NllLossBackward>)
tensor(2.3274,grad_fn=<NllLossBackward>)
tensor(2.3975,grad_fn=<NllLossBackward>)
tensor(2.3870,grad_fn=<NllLossBackward>)
tensor(2.4181,grad_fn=<NllLossBackward>)
tensor(2.2376,grad_fn=<NllLossBackward>)
tensor(2.2043,grad_fn=<NllLossBackward>)
tensor(2.0942,grad_fn=<NllLossBackward>)
tensor(2.3136,grad_fn=<NllLossBackward>)
tensor(2.3545,grad_fn=<NllLossBackward>)
tensor(2.3709,grad_fn=<NllLossBackward>)
tensor(2.3261,grad_fn=<NllLossBackward>)
tensor(2.3956,grad_fn=<NllLossBackward>)
tensor(2.3857,grad_fn=<NllLossBackward>)
tensor(2.4168,grad_fn=<NllLossBackward>)
tensor(2.2378,grad_fn=<NllLossBackward>)
tensor(2.2052,grad_fn=<NllLossBackward>)
tensor(2.0957,grad_fn=<NllLossBackward>)
tensor(2.3541,grad_fn=<NllLossBackward>)
tensor(2.3693,grad_fn=<NllLossBackward>)
tensor(2.3250,grad_fn=<NllLossBackward>)
tensor(2.3940,grad_fn=<NllLossBackward>)
tensor(2.3846,grad_fn=<NllLossBackward>)
tensor(2.4155,grad_fn=<NllLossBackward>)
tensor(2.2380,grad_fn=<NllLossBackward>)
tensor(2.2060,grad_fn=<NllLossBackward>)
tensor(2.0972,grad_fn=<NllLossBackward>)
tensor(2.3141,grad_fn=<NllLossBackward>)
tensor(2.3539,grad_fn=<NllLossBackward>)
tensor(2.3679,grad_fn=<NllLossBackward>)
tensor(2.3240,grad_fn=<NllLossBackward>)
tensor(2.3925,grad_fn=<NllLossBackward>)
tensor(2.3836,grad_fn=<NllLossBackward>)
tensor(2.4144,grad_fn=<NllLossBackward>)
tensor(2.2382,grad_fn=<NllLossBackward>)
tensor(2.2068,grad_fn=<NllLossBackward>)
tensor(2.0985,grad_fn=<NllLossBackward>)
tensor(2.3144,grad_fn=<NllLossBackward>)
tensor(2.3537,grad_fn=<NllLossBackward>)
tensor(2.3667,grad_fn=<NllLossBackward>)
tensor(2.3232,grad_fn=<NllLossBackward>)
tensor(2.3912,grad_fn=<NllLossBackward>)
tensor(2.3828,grad_fn=<NllLossBackward>)
tensor(2.4134,grad_fn=<NllLossBackward>)
tensor(2.2384,grad_fn=<NllLossBackward>)
tensor(2.2075,grad_fn=<NllLossBackward>)
tensor(2.0998,grad_fn=<NllLossBackward>)
tensor(2.3147,grad_fn=<NllLossBackward>)
tensor(2.3535,grad_fn=<NllLossBackward>)
tensor(2.3656,grad_fn=<NllLossBackward>)
tensor(2.3224,grad_fn=<NllLossBackward>)
tensor(2.3900,grad_fn=<NllLossBackward>)
tensor(2.3819,grad_fn=<NllLossBackward>)
tensor(2.4125,grad_fn=<NllLossBackward>)
tensor(2.2386,grad_fn=<NllLossBackward>)
tensor(2.2082,grad_fn=<NllLossBackward>)
tensor(2.1009,grad_fn=<NllLossBackward>)
tensor(2.3150,grad_fn=<NllLossBackward>)
tensor(2.3534,grad_fn=<NllLossBackward>)

Well, can you help me? This RNN should predict which class a sequence (like [10,1]) belongs to.
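In other words, at test time the predicted class should just be the argmax over the 10 outputs, something like this (output_decoder being the (1,10) tensor my decoder returns):

predicted_class = torch.argmax(output_decoder,dim=1)   # index 0..9 of the most likely class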

If possible, please change my code so that I can look at a working example and improve my understanding. I also know that attention is needed when an LSTM has to handle very long sequence lengths, but I just want to understand the concept of such an architecture. Or could you write a new simplest example for me? Well, any help would be appreciated.

  [1]: https://i.stack.imgur.com/D2V0S.jpg
