
How to fix loss oscillating instead of decreasing (seq2seq GRU PyTorch)

I am trying to build a seq2seq model with a PyTorch GRU for sentiment extraction. Examples of sentiment extraction:

Input 1: "So sad, I will miss you in San Diego!!!"  Target 1: "So sad"
Input 2: "My boss is bullying me..."  Target 2: "bullying me"
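
For context, the pairs variable used in the training code below is assumed to hold such (text, selected phrase) tuples; the entries here are just the two illustrative examples above:

pairs = [
    ("So sad, I will miss you in San Diego!!!", "So sad"),
    ("My boss is bullying me...", "bullying me"),
]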

My model:

import torch
import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self,vocab_size,embedding_size,hidden_size,dropout_odds=0.2):
        super(Encoder,self).__init__()
        self.dropout = nn.Dropout(dropout_odds)
        self.embedding = nn.Embedding(vocab_size,embedding_size)
        # bidirectional GRU over the embedded token
        self.gru = nn.GRU(embedding_size,hidden_size,bidirectional=True)
        # project the concatenated forward/backward states back to hidden_size
        self.fc_hidden = nn.Linear(hidden_size*2,hidden_size)

    def forward(self,x):
        # x is a single token index; reshape to (seq_len=1, batch=1, embedding_size)
        embedded = self.dropout(self.embedding(x)).view(1,1,-1)
        output,hidden = self.gru(embedded)
        # hidden: (2, 1, hidden_size) -> concatenate both directions, then project
        hidden = self.fc_hidden(torch.cat((hidden[0:1],hidden[1:2]),dim=2))

        return output,hidden
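
As a sanity check on the dimensions, here is a minimal sketch of a single encoder step (placeholder sizes; it assumes the Encoder class above, which the Seq2seq loop below feeds one token at a time):

# Hypothetical sizes, CPU only; assumes the Encoder class above.
enc = Encoder(vocab_size=100, embedding_size=8, hidden_size=8)
out, hidden = enc(torch.tensor([3]))   # a single token index
print(out.shape)      # torch.Size([1, 1, 16]) -> (seq_len, batch, 2*hidden_size)
print(hidden.shape)   # torch.Size([1, 1, 8])  -> projected back to hidden_size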

class Decoder(nn.Module):
    def __init__(self,vocab_size,embedding_size,hidden_size,dropout_odds=0.2):
        super(Decoder,self).__init__()

        self.dropout = nn.Dropout(dropout_odds)
        self.embedding = nn.Embedding(vocab_size,embedding_size)
        # attention energy over [decoder hidden ; encoder output] at each source position
        self.energy = nn.Linear(hidden_size*3,1)
        # GRU input is [context vector ; embedded previous token]
        self.gru = nn.GRU(hidden_size*2+embedding_size,hidden_size)
        self.out = nn.Linear(hidden_size,vocab_size)
        self.relu = nn.ReLU()
        self.softmax_en = nn.Softmax(dim=0)
        self.softmax_out = nn.Softmax(dim=2)

    def forward(self,x,hidden,encoder_outputs):
        # x is the previous target token; reshape to (1, 1, embedding_size)
        embedded = self.dropout(self.embedding(x)).view(1,1,-1)
        sequence_length = encoder_outputs.shape[0]
        # broadcast the decoder hidden state across all source positions
        h_reshaped = hidden[0].repeat(sequence_length,1)
        energy_ = self.relu(self.energy(torch.cat((h_reshaped,encoder_outputs),dim=1)))
        attention = self.softmax_en(energy_)
        # attention-weighted sum of encoder outputs -> context vector (1, hidden_size*2)
        context_vector = torch.einsum("st,sh->th",attention,encoder_outputs)
        output = torch.cat((context_vector.unsqueeze(0),embedded),dim=2)
        output,hidden = self.gru(output,hidden)
        output = self.out(output)
        return output,hidden
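
Likewise, a minimal sketch of one decoder step with attention (placeholder sizes; the zero tensors stand in for the encoder outputs and hidden state that the Seq2seq loop produces):

# Hypothetical sizes; assumes the Decoder class above.
dec = Decoder(vocab_size=100, embedding_size=8, hidden_size=8)
encoder_outputs = torch.zeros(5, 16)   # (source_len, 2*hidden_size)
hidden = torch.zeros(1, 1, 8)          # (num_layers, batch, hidden_size)
out, hidden = dec(torch.tensor([3]), hidden, encoder_outputs)
print(out.shape)      # torch.Size([1, 1, 100]) -> logits over the vocabulary
print(hidden.shape)   # torch.Size([1, 1, 8])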

class Seq2seq(nn.Module):
    def __init__(self,encoder,decoder):
        super(Seq2seq,self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self,source,target,teacher_force_ratio=0.5):
        source_len = source.shape[0]
        target_len = target.shape[0]
        target_vocab_size = vocab.n_words
        encoder_outputs = torch.zeros(source_len,hidden_size*2,device=device)
        decoder_outputs = torch.zeros(target_len,target_vocab_size,device=device)
        # encode the source one token at a time (each call starts from a fresh hidden state)
        for ei in range(source_len):
            encoder_states,hidden = self.encoder(source[ei])
            encoder_outputs[ei] = encoder_states[0,0]
        x = target[0]

        # decode step by step, always feeding back the previous best guess
        for t in range(1,target_len):
            output,hidden = self.decoder(x,hidden,encoder_outputs)
            decoder_outputs[t] = output[0,0]
            best_guess = output.argmax()
            x = best_guess

        return decoder_outputs

hidden_size = 1024
embedding_size = 1024
epochs=100
encoder1 = Encoder(vocab.n_words,embedding_size,hidden_size).to(device)
attn_decoder1 = Decoder(vocab.n_words,embedding_size,hidden_size).to(device)
model = Seq2seq(encoder1,attn_decoder1).to(device)
# optimizer = optim.SGD(model.parameters(),lr=0.05)
optimizer = optim.Adam(model.parameters(),lr=0.1)
criterion = nn.CrossEntropyLoss()
training_pairs = [tensorsFromPair(pair) for pair in pairs]

model.train()
for epoch in range(1,epochs+1):
    print(f"[Epoch {epoch} / {epochs}]")
    
    indx = 0
    for tp in training_pairs[:100]:
        
        inp_data = tp[0]
        target = tp[1]
        
        output = model(inp_data,target)
        optimizer.zero_grad()
        loss = criterion(output.squeeze(1),target)
        loss.backward()
        optimizer.step()
        indx+=1
        if indx%100==0:
            print('loss: ',loss)        

training_pairs is a list of tuples of tensors; each tuple holds an input tensor and a target tensor, and each element is the index of a word in the vocabulary:

training_pairs[0] = (tensor([18,19,20,21,22,1]),tensor([21,1])), where 1 is the EOS_token
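
tensorsFromPair itself is not shown above; a minimal sketch of what it is assumed to do, based on the example pair (vocab.word2index and the EOS_token value of 1 are assumptions):

EOS_token = 1

def tensorsFromPair(pair):
    # Hypothetical helper: map each word to its vocabulary index and append EOS.
    def to_tensor(sentence):
        indexes = [vocab.word2index[word] for word in sentence.split()] + [EOS_token]
        return torch.tensor(indexes, dtype=torch.long, device=device)
    return (to_tensor(pair[0]), to_tensor(pair[1]))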

For some reason the loss does not decrease as the model trains; instead it oscillates:

[Epoch 1 / 100]
loss:  tensor(974.1525,device='cuda:0',grad_fn=<NllLossBackward>)
[Epoch 2 / 100]
loss:  tensor(5.4103,grad_fn=<NllLossBackward>)
[Epoch 3 / 100]
loss:  tensor(3404.2073,grad_fn=<NllLossBackward>)
[Epoch 4 / 100]
loss:  tensor(5.4103,grad_fn=<NllLossBackward>)
[Epoch 5 / 100]
loss:  tensor(2543.9885,grad_fn=<NllLossBackward>)
[Epoch 6 / 100]
loss:  tensor(28.9650,grad_fn=<NllLossBackward>)
[Epoch 7 / 100]
loss:  tensor(2998.9417,grad_fn=<NllLossBackward>)
[Epoch 8 / 100]
loss:  tensor(5.4103,grad_fn=<NllLossBackward>)
[Epoch 9 / 100]
loss:  tensor(5.4103,grad_fn=<NllLossBackward>)
[Epoch 10 / 100]
loss:  tensor(5.4103,grad_fn=<NllLossBackward>)

The same thing happens if the encoder input is changed to the SOS token. I also tried computing the loss without the EOS token, and it does not change anything. I don't understand why the loss is not being computed correctly.
